提交 a0acfc6a 编写于 作者: Z zchen0211

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into develop

...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) { Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) { switch (attr_desc.type()) {
case framework::AttrType::BOOLEAN: { case framework::AttrType::BOOLEAN: {
return attr_desc.b(); return attr_desc.b();
...@@ -61,13 +61,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) { ...@@ -61,13 +61,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) {
} }
return val; return val;
} }
case framework::AttrType::BLOCK: { default:
PADDLE_ENFORCE(program != nullptr, PADDLE_THROW("Unsupport attr type %d", attr_desc.type());
"Need to specify ProgramDesc when get a block attr");
return program->mutable_blocks(attr_desc.block_idx());
}
} }
PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
return boost::blank(); return boost::blank();
} }
......
...@@ -32,7 +32,7 @@ inline AttrType AttrTypeID() { ...@@ -32,7 +32,7 @@ inline AttrType AttrTypeID() {
return static_cast<AttrType>(tmp.which() - 1); return static_cast<AttrType>(tmp.which() - 1);
} }
Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* desc); Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
class AttrReader { class AttrReader {
public: public:
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <deque> #include <deque>
#include <list> #include <list>
#include <memory> #include <memory>
#include <unordered_set>
#include "paddle/framework/block_desc.h" #include "paddle/framework/block_desc.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
...@@ -285,6 +286,15 @@ static bool AllGradInSet(const std::vector<std::string>& names, ...@@ -285,6 +286,15 @@ static bool AllGradInSet(const std::vector<std::string>& names,
return true; return true;
} }
static std::string FwdName(const std::string& grad_name) {
auto pos = grad_name.find("@GRAD");
if (pos == std::string::npos) {
return "";
} else {
return grad_name.substr(0, pos);
}
}
static void CreateGradVarInBlock( static void CreateGradVarInBlock(
size_t grad_op_start_index, size_t grad_op_start_index,
const std::unordered_map<std::string, std::string>& param_name_map, const std::unordered_map<std::string, std::string>& param_name_map,
...@@ -294,6 +304,7 @@ static void CreateGradVarInBlock( ...@@ -294,6 +304,7 @@ static void CreateGradVarInBlock(
for (size_t op_index = grad_op_start_index; op_index < ops.size(); for (size_t op_index = grad_op_start_index; op_index < ops.size();
++op_index) { ++op_index) {
bool need_infer_shape = false; bool need_infer_shape = false;
std::unordered_set<std::string> new_vars;
ForEachVarName(ops[op_index]->Outputs(), ForEachVarName(ops[op_index]->Outputs(),
[&](const std::string& grad_var_name) { [&](const std::string& grad_var_name) {
if (block_desc->HasVar(grad_var_name)) { if (block_desc->HasVar(grad_var_name)) {
...@@ -301,8 +312,7 @@ static void CreateGradVarInBlock( ...@@ -301,8 +312,7 @@ static void CreateGradVarInBlock(
} }
need_infer_shape = true; need_infer_shape = true;
auto var = block_desc->Var(grad_var_name); auto var = block_desc->Var(grad_var_name);
// FIXME(qiao) infer the datatype new_vars.insert(var->Name());
var->SetDataType(framework::DataType::FP32);
auto it = param_name_map.find(grad_var_name); auto it = param_name_map.find(grad_var_name);
if (it == param_name_map.end()) { if (it == param_name_map.end()) {
return false; return false;
...@@ -316,6 +326,21 @@ static void CreateGradVarInBlock( ...@@ -316,6 +326,21 @@ static void CreateGradVarInBlock(
}); });
if (need_infer_shape) { if (need_infer_shape) {
ops[op_index]->InferVarType(block_desc); ops[op_index]->InferVarType(block_desc);
for (auto& arg : ops[op_index]->OutputArgumentNames()) {
if (new_vars.find(arg) == new_vars.end()) {
continue;
}
auto pname = FwdName(arg);
auto* param = block_desc->FindVar(pname);
auto* grad = block_desc->FindVar(arg);
if (param == nullptr) {
LOG(WARNING) << "Cannot find forward variable of " << arg
<< ". Set its gradient to FP32";
grad->SetDataType(DataType::FP32);
} else {
grad->SetDataType(param->GetDataType());
}
}
ops[op_index]->InferShape(*block_desc); ops[op_index]->InferShape(*block_desc);
} }
} }
...@@ -368,7 +393,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -368,7 +393,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
ProgramDescBind& program_desc, int block_idx, ProgramDescBind& program_desc, int block_idx,
std::unordered_set<std::string>* no_grad_vars, std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var) { std::unordered_map<std::string, std::string>* grad_to_var) {
BlockDescBind* cur_block = program_desc.Block(block_idx); BlockDescBind* cur_block = program_desc.MutableBlock(block_idx);
std::vector<OpDescBind*> op_descs = cur_block->AllOps(); std::vector<OpDescBind*> op_descs = cur_block->AllOps();
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops; std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
size_t grad_desc_idx = 0; size_t grad_desc_idx = 0;
...@@ -443,7 +468,7 @@ ParamGradInfoMap AppendBackward( ...@@ -443,7 +468,7 @@ ParamGradInfoMap AppendBackward(
} }
const int root_block_idx = 0; const int root_block_idx = 0;
auto root_block = program_desc.Block(root_block_idx); auto root_block = program_desc.MutableBlock(root_block_idx);
// insert fill one op for target // insert fill one op for target
// TODO(qiao) add some check to the target. // TODO(qiao) add some check to the target.
...@@ -492,7 +517,7 @@ ParamGradInfoMap AppendBackward( ...@@ -492,7 +517,7 @@ ParamGradInfoMap AppendBackward(
CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv); CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
for (size_t block_index = forward_block_num; for (size_t block_index = forward_block_num;
block_index < program_desc.Size(); ++block_index) { block_index < program_desc.Size(); ++block_index) {
CreateGradVarInBlock(0, grad_to_var, program_desc.Block(block_index), CreateGradVarInBlock(0, grad_to_var, program_desc.MutableBlock(block_index),
&retv); &retv);
} }
return retv; return retv;
......
...@@ -499,7 +499,7 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ...@@ -499,7 +499,7 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
TEST(Backward, simple_single_op) { TEST(Backward, simple_single_op) {
f::ProgramDescBind program; f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.MutableBlock(0);
f::OpDescBind *op = block->AppendOp(); f::OpDescBind *op = block->AppendOp();
op->SetType("rowwise_add"); op->SetType("rowwise_add");
...@@ -535,7 +535,7 @@ TEST(Backward, simple_single_op) { ...@@ -535,7 +535,7 @@ TEST(Backward, simple_single_op) {
TEST(Backward, default_attribute) { TEST(Backward, default_attribute) {
f::ProgramDescBind program; f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.MutableBlock(0);
f::OpDescBind *op = block->AppendOp(); f::OpDescBind *op = block->AppendOp();
op->SetType("mul"); op->SetType("mul");
op->SetInput("X", {"x"}); op->SetInput("X", {"x"});
...@@ -561,7 +561,7 @@ TEST(Backward, default_attribute) { ...@@ -561,7 +561,7 @@ TEST(Backward, default_attribute) {
TEST(Backward, simple_mult_op) { TEST(Backward, simple_mult_op) {
f::ProgramDescBind program; f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.MutableBlock(0);
f::OpDescBind *op1 = block->AppendOp(); f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add"); op1->SetType("rowwise_add");
op1->SetInput("X", {"x1"}); op1->SetInput("X", {"x1"});
...@@ -644,7 +644,7 @@ TEST(Backward, simple_mult_op) { ...@@ -644,7 +644,7 @@ TEST(Backward, simple_mult_op) {
TEST(Backward, intermedia_var_no_grad) { TEST(Backward, intermedia_var_no_grad) {
f::ProgramDescBind program; f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.MutableBlock(0);
f::OpDescBind *op1 = block->AppendOp(); f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add"); op1->SetType("rowwise_add");
op1->SetInput("X", {"x1"}); op1->SetInput("X", {"x1"});
...@@ -714,7 +714,7 @@ TEST(Backward, intermedia_var_no_grad) { ...@@ -714,7 +714,7 @@ TEST(Backward, intermedia_var_no_grad) {
TEST(Backward, var_no_grad) { TEST(Backward, var_no_grad) {
f::ProgramDescBind program; f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.MutableBlock(0);
f::OpDescBind *op1 = block->AppendOp(); f::OpDescBind *op1 = block->AppendOp();
op1->SetType("mult_in_out"); op1->SetType("mult_in_out");
op1->SetInput("X", {"x1"}); op1->SetInput("X", {"x1"});
...@@ -790,7 +790,7 @@ TEST(Backward, var_no_grad) { ...@@ -790,7 +790,7 @@ TEST(Backward, var_no_grad) {
TEST(Backward, shared_var) { TEST(Backward, shared_var) {
f::ProgramDescBind program; f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.MutableBlock(0);
f::OpDescBind *op1 = block->AppendOp(); f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add"); op1->SetType("rowwise_add");
op1->SetInput("X", {"x1"}); op1->SetInput("X", {"x1"});
...@@ -880,7 +880,7 @@ TEST(Backward, shared_var) { ...@@ -880,7 +880,7 @@ TEST(Backward, shared_var) {
TEST(Backward, half_backward) { TEST(Backward, half_backward) {
f::ProgramDescBind program; f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.MutableBlock(0);
auto *op1 = block->AppendOp(); auto *op1 = block->AppendOp();
op1->SetType("minus"); op1->SetType("minus");
op1->SetInput("X", {"a"}); op1->SetInput("X", {"a"});
......
...@@ -113,7 +113,7 @@ BlockDescBind *BlockDescBind::ParentBlock() const { ...@@ -113,7 +113,7 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
if (this->desc_->parent_idx() == kNoneBlockIndex) { if (this->desc_->parent_idx() == kNoneBlockIndex) {
return nullptr; return nullptr;
} }
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx())); return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
} }
BlockDesc *BlockDescBind::Proto() { BlockDesc *BlockDescBind::Proto() {
......
...@@ -73,33 +73,32 @@ static void CreateTensor(Variable* var, VarDesc::VarType var_type) { ...@@ -73,33 +73,32 @@ static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
} }
} }
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) { void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id) {
// TODO(tonyyang-svail): // TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication) // - only runs on the first device (i.e. no interdevice communication)
// - will change to use multiple blocks for RNN op and Cond Op // - will change to use multiple blocks for RNN op and Cond Op
PADDLE_ENFORCE_GT(pdesc.blocks_size(), block_id); PADDLE_ENFORCE_LT(block_id, pdesc.Size());
auto& block = pdesc.blocks(block_id); auto& block = pdesc.Block(block_id);
auto& device = device_contexts_[0]; auto& device = device_contexts_[0];
Scope& local_scope = scope->NewScope(); Scope& local_scope = scope->NewScope();
for (auto& var : block.vars()) { for (auto& var : block.AllVars()) {
if (var.persistable()) { if (var->Persistable()) {
auto* ptr = scope->Var(var.name()); auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var.type()); CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var.name() VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr; << " global, which pointer is " << ptr;
} else { } else {
auto* ptr = local_scope.Var(var.name()); auto* ptr = local_scope.Var(var->Name());
CreateTensor(ptr, var.type()); CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var.name() VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr; << " locally, which pointer is " << ptr;
} }
} }
for (auto& op_desc : block.ops()) { for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp( auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
op_desc, const_cast<ProgramDesc*>(&pdesc));
op->Run(local_scope, *device); op->Run(local_scope, *device);
} }
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_info.h" #include "paddle/framework/op_info.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
...@@ -34,7 +34,7 @@ class Executor { ...@@ -34,7 +34,7 @@ class Executor {
* ProgramDesc * ProgramDesc
* Scope * Scope
*/ */
void Run(const ProgramDesc&, Scope*, int); void Run(const ProgramDescBind&, Scope*, int);
private: private:
std::vector<platform::DeviceContext*> device_contexts_; std::vector<platform::DeviceContext*> device_contexts_;
......
...@@ -52,6 +52,22 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -52,6 +52,22 @@ class CompileTimeInferShapeContext : public InferShapeContext {
const std::vector<std::string> &Outputs( const std::vector<std::string> &Outputs(
const std::string &name) const override; const std::string &name) const override;
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
if (in_var->GetType() != VarDesc::LOD_TENSOR) {
VLOG(3) << "input " << in << "is not LodTensor";
return;
}
PADDLE_ENFORCE_EQ(in_var->GetType(), VarDesc::LOD_TENSOR,
"The %d-th output of Output(%s) must be LoDTensor.", j,
out);
in_var->SetLoDLevel(out_var->GetLodLevel());
}
private: private:
DDim GetDim(const std::string &name) const override; DDim GetDim(const std::string &name) const override;
...@@ -98,7 +114,12 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog) ...@@ -98,7 +114,12 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
// restore attrs_ // restore attrs_
for (const OpDesc::Attr &attr : desc_.attrs()) { for (const OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name(); std::string attr_name = attr.name();
attrs_[attr_name] = GetAttrValue(attr, prog->Proto()); if (attr.type() != AttrType::BLOCK) {
attrs_[attr_name] = GetAttrValue(attr);
} else {
auto bid = attr.block_idx();
attrs_[attr_name] = prog->MutableBlock(bid);
}
} }
} }
...@@ -172,8 +193,7 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) { ...@@ -172,8 +193,7 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
} }
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
BlockDesc *desc = block.Proto(); this->attrs_[name] = &block;
this->attrs_[name] = desc;
need_update_ = true; need_update_ = true;
} }
...@@ -192,7 +212,7 @@ Attribute OpDescBind::GetAttr(const std::string &name) const { ...@@ -192,7 +212,7 @@ Attribute OpDescBind::GetAttr(const std::string &name) const {
int OpDescBind::GetBlockAttr(const std::string &name) const { int OpDescBind::GetBlockAttr(const std::string &name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return boost::get<BlockDesc *>(it->second)->idx(); return boost::get<BlockDescBind *>(it->second)->ID();
} }
const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap() const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
......
...@@ -43,13 +43,15 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap( ...@@ -43,13 +43,15 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
return ret_val; return ret_val;
} }
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc, std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
ProgramDesc* program) { VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be"
"used in unit tests. Use CreateOp(const OpDescBind& op_desc) "
"instead.";
VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs; AttributeMap attrs;
for (auto& attr : op_desc.attrs()) { for (auto& attr : op_desc.attrs()) {
attrs[attr.name()] = GetAttrValue(attr, program); attrs[attr.name()] = GetAttrValue(attr);
} }
return CreateOp(op_desc.type(), inputs, outputs, attrs); return CreateOp(op_desc.type(), inputs, outputs, attrs);
......
...@@ -77,8 +77,7 @@ class OpRegistry { ...@@ -77,8 +77,7 @@ class OpRegistry {
const VariableNameMap& outputs, const VariableNameMap& outputs,
AttributeMap attrs); AttributeMap attrs);
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc, static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
ProgramDesc* program);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
}; };
......
...@@ -74,7 +74,7 @@ TEST(OpRegistry, CreateOp) { ...@@ -74,7 +74,7 @@ TEST(OpRegistry, CreateOp) {
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(scale); attr->set_f(scale);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope; paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
...@@ -95,7 +95,7 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -95,7 +95,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::platform::EnforceNotMet err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "larger_than check fail"; std::string msg = "larger_than check fail";
...@@ -115,7 +115,7 @@ TEST(OpRegistry, DefaultValue) { ...@@ -115,7 +115,7 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_TRUE(op_desc.IsInitialized()); ASSERT_TRUE(op_desc.IsInitialized());
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope; paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
...@@ -131,7 +131,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -131,7 +131,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set // attr 'test_attr' is not set
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::platform::EnforceNotMet err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "Attribute 'test_attr' is required!"; std::string msg = "Attribute 'test_attr' is required!";
...@@ -149,7 +149,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -149,7 +149,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(3); attr->set_i(3);
caught = false; caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::platform::EnforceNotMet err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "'test_attr' must be even!"; std::string msg = "'test_attr' must be even!";
...@@ -166,7 +166,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -166,7 +166,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_name("test_attr"); attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT); attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4); attr->set_i(4);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
paddle::framework::Scope scope; paddle::framework::Scope scope;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
......
...@@ -351,6 +351,20 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -351,6 +351,20 @@ class RuntimeInferShapeContext : public InferShapeContext {
return op_.Outputs(name); return op_.Outputs(name);
} }
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
Variable* in_var = scope_.FindVar(Inputs(in)[i]);
Variable* out_var = scope_.FindVar(Outputs(out)[j]);
if (!in_var->IsType<LoDTensor>()) return;
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod());
}
private: private:
DDim GetDim(const std::string& name) const override { DDim GetDim(const std::string& name) const override {
Variable* var = scope_.FindVar(name); Variable* var = scope_.FindVar(name);
......
...@@ -83,7 +83,7 @@ TEST(OperatorBase, all) { ...@@ -83,7 +83,7 @@ TEST(OperatorBase, all) {
paddle::platform::CPUDeviceContext device_context; paddle::platform::CPUDeviceContext device_context;
paddle::framework::Scope scope; paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
scope.Var("OUT1"); scope.Var("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0); ASSERT_EQ(paddle::framework::op_run_num, 0);
op->Run(scope, device_context); op->Run(scope, device_context);
...@@ -208,7 +208,7 @@ TEST(OpKernel, all) { ...@@ -208,7 +208,7 @@ TEST(OpKernel, all) {
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;
paddle::framework::Scope scope; paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
...@@ -244,7 +244,7 @@ TEST(OpKernel, multi_inputs) { ...@@ -244,7 +244,7 @@ TEST(OpKernel, multi_inputs) {
scope.Var("y0")->GetMutable<LoDTensor>(); scope.Var("y0")->GetMutable<LoDTensor>();
scope.Var("y1")->GetMutable<LoDTensor>(); scope.Var("y1")->GetMutable<LoDTensor>();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
} }
......
...@@ -37,7 +37,9 @@ class ProgramDescBind { ...@@ -37,7 +37,9 @@ class ProgramDescBind {
BlockDescBind *AppendBlock(const BlockDescBind &parent); BlockDescBind *AppendBlock(const BlockDescBind &parent);
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); } BlockDescBind *MutableBlock(size_t idx) { return blocks_[idx].get(); }
const BlockDescBind &Block(size_t idx) const { return *blocks_[idx]; }
size_t Size() const { return blocks_.size(); } size_t Size() const { return blocks_.size(); }
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace framework { namespace framework {
TEST(ProgramDesc, copy_ctor) { TEST(ProgramDesc, copy_ctor) {
ProgramDescBind program; ProgramDescBind program;
auto* global_block = program.Block(0); auto* global_block = program.MutableBlock(0);
auto* x = global_block->Var("X"); auto* x = global_block->Var("X");
x->SetType(VarDesc_VarType_LOD_TENSOR); x->SetType(VarDesc_VarType_LOD_TENSOR);
x->SetLoDLevel(0); x->SetLoDLevel(0);
...@@ -44,7 +44,7 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -44,7 +44,7 @@ TEST(ProgramDesc, copy_ctor) {
ProgramDescBind program_copy(program); ProgramDescBind program_copy(program);
auto* global_block_copy = program_copy.Block(0); auto* global_block_copy = program_copy.MutableBlock(0);
ASSERT_NE(global_block, global_block_copy); ASSERT_NE(global_block, global_block_copy);
auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) { auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
...@@ -82,7 +82,7 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -82,7 +82,7 @@ TEST(ProgramDesc, copy_ctor) {
TEST(ProgramDescBind, serialize_and_deserialize) { TEST(ProgramDescBind, serialize_and_deserialize) {
ProgramDescBind program_origin; ProgramDescBind program_origin;
auto* global_block = program_origin.Block(0); auto* global_block = program_origin.MutableBlock(0);
auto* x = global_block->Var("X"); auto* x = global_block->Var("X");
x->SetType(VarDesc_VarType_LOD_TENSOR); x->SetType(VarDesc_VarType_LOD_TENSOR);
x->SetLoDLevel(0); x->SetLoDLevel(0);
...@@ -108,7 +108,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) { ...@@ -108,7 +108,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
program_origin.Proto()->SerializeToString(&binary_str); program_origin.Proto()->SerializeToString(&binary_str);
ProgramDescBind program_restored(binary_str); ProgramDescBind program_restored(binary_str);
auto* global_block_restored = program_restored.Block(0); auto* global_block_restored = program_restored.MutableBlock(0);
ASSERT_NE(global_block, global_block_restored); ASSERT_NE(global_block, global_block_restored);
auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) { auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
......
...@@ -52,7 +52,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs, ...@@ -52,7 +52,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
TEST(Prune, one_operator) { TEST(Prune, one_operator) {
f::ProgramDescBind program; f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.MutableBlock(0);
AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block); AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block);
...@@ -69,7 +69,7 @@ TEST(Prune, one_operator) { ...@@ -69,7 +69,7 @@ TEST(Prune, one_operator) {
TEST(Prune, forward) { TEST(Prune, forward) {
f::ProgramDescBind program; f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.MutableBlock(0);
AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block); AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block);
AddOp("one_one", {{"input", {"b"}}}, {{"output", {"c"}}}, {}, block); AddOp("one_one", {{"input", {"b"}}}, {{"output", {"c"}}}, {}, block);
...@@ -88,7 +88,7 @@ TEST(Prune, forward) { ...@@ -88,7 +88,7 @@ TEST(Prune, forward) {
TEST(Prune, multi_input_op) { TEST(Prune, multi_input_op) {
f::ProgramDescBind program; f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.MutableBlock(0);
AddOp("one_one", {{"input", {"a0"}}}, {{"output", {"b0"}}}, {}, block); AddOp("one_one", {{"input", {"a0"}}}, {{"output", {"b0"}}}, {}, block);
AddOp("one_one", {{"input", {"a1"}}}, {{"output", {"b1"}}}, {}, block); AddOp("one_one", {{"input", {"a1"}}}, {{"output", {"b1"}}}, {}, block);
...@@ -106,7 +106,7 @@ TEST(Prune, multi_input_op) { ...@@ -106,7 +106,7 @@ TEST(Prune, multi_input_op) {
TEST(Prune, multi_output_op) { TEST(Prune, multi_output_op) {
f::ProgramDescBind program; f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.MutableBlock(0);
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block); AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block);
AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block); AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block);
...@@ -122,7 +122,7 @@ TEST(Prune, multi_output_op) { ...@@ -122,7 +122,7 @@ TEST(Prune, multi_output_op) {
TEST(Prune, multi_target) { TEST(Prune, multi_target) {
f::ProgramDescBind program; f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.MutableBlock(0);
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block); AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block);
AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block); AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block);
......
...@@ -28,9 +28,6 @@ void InferShapeContext::SetOutputsDim( ...@@ -28,9 +28,6 @@ void InferShapeContext::SetOutputsDim(
SetDims(names, dims); SetDims(names, dims);
} }
void InferShapeContext::ShareLoD(const std::string &in, const std::string &out,
size_t i, size_t j) const {}
std::vector<framework::DDim> InferShapeContext::GetDims( std::vector<framework::DDim> InferShapeContext::GetDims(
const std::vector<std::string> &names) const { const std::vector<std::string> &names) const {
std::vector<framework::DDim> ret; std::vector<framework::DDim> ret;
......
...@@ -43,9 +43,8 @@ class InferShapeContext { ...@@ -43,9 +43,8 @@ class InferShapeContext {
virtual const std::vector<std::string> &Outputs( virtual const std::vector<std::string> &Outputs(
const std::string &name) const = 0; const std::string &name) const = 0;
// TODO(qiao) implement this function virtual void ShareLoD(const std::string &in, const std::string &out,
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, size_t i = 0, size_t j = 0) const = 0;
size_t j = 0) const;
protected: protected:
virtual framework::DDim GetDim(const std::string &name) const = 0; virtual framework::DDim GetDim(const std::string &name) const = 0;
......
...@@ -36,7 +36,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>; ...@@ -36,7 +36,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using Attribute = using Attribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>, boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool, std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*>; std::vector<bool>, BlockDescBind*>;
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
......
...@@ -63,41 +63,43 @@ namespace framework { ...@@ -63,41 +63,43 @@ namespace framework {
TEST(InferVarType, sum_op) { TEST(InferVarType, sum_op) {
ProgramDescBind prog; ProgramDescBind prog;
auto *op = prog.Block(0)->AppendOp(); auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum"); op->SetType("sum");
op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetInput("X", {"test_a", "test_b", "test_c"});
op->SetOutput("Out", {"test_out"}); op->SetOutput("Out", {"test_out"});
prog.Block(0)->Var("test_a")->SetType(VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_a")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->Var("test_b")->SetType(VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_b")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->Var("test_c")->SetType(VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test_c")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->Var("test_out"); prog.MutableBlock(0)->Var("test_out");
op->InferVarType(prog.Block(0)); op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(VarDesc::SELECTED_ROWS, prog.Block(0)->Var("test_out")->GetType()); ASSERT_EQ(VarDesc::SELECTED_ROWS,
prog.MutableBlock(0)->Var("test_out")->GetType());
prog.Block(0)->Var("test_b")->SetType(VarDesc::LOD_TENSOR); prog.MutableBlock(0)->Var("test_b")->SetType(VarDesc::LOD_TENSOR);
op->InferVarType(prog.Block(0)); op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(VarDesc::LOD_TENSOR, prog.Block(0)->Var("test_out")->GetType()); ASSERT_EQ(VarDesc::LOD_TENSOR,
prog.MutableBlock(0)->Var("test_out")->GetType());
} }
TEST(InferVarType, sum_op_without_infer_var_type) { TEST(InferVarType, sum_op_without_infer_var_type) {
ProgramDescBind prog; ProgramDescBind prog;
auto *op = prog.Block(0)->AppendOp(); auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum_without_infer_var_type"); op->SetType("sum_without_infer_var_type");
op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
op->SetOutput("Out", {"test2_out"}); op->SetOutput("Out", {"test2_out"});
prog.Block(0)->Var("test2_a")->SetType(VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test2_a")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->Var("test2_b")->SetType(VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test2_b")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->Var("test2_c")->SetType(VarDesc::SELECTED_ROWS); prog.MutableBlock(0)->Var("test2_c")->SetType(VarDesc::SELECTED_ROWS);
prog.Block(0)->Var("test2_out"); prog.MutableBlock(0)->Var("test2_out");
op->InferVarType(prog.Block(0)); op->InferVarType(prog.MutableBlock(0));
ASSERT_EQ(VarDesc_VarType_LOD_TENSOR, ASSERT_EQ(VarDesc_VarType_LOD_TENSOR,
prog.Block(0)->Var("test2_out")->GetType()); prog.MutableBlock(0)->Var("test2_out")->GetType());
} }
} // namespace framework } // namespace framework
......
...@@ -51,7 +51,7 @@ class RNNAlgorithmTestHelper : public ::testing::Test { ...@@ -51,7 +51,7 @@ class RNNAlgorithmTestHelper : public ::testing::Test {
CreateGlobalVariables(); CreateGlobalVariables();
auto op_desc = CreateOpDesc(); auto op_desc = CreateOpDesc();
op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); op = paddle::framework::OpRegistry::CreateOp(op_desc);
dop = &(dynamic_cast<DynamicRecurrentOp*>(op.get())->rnn); dop = &(dynamic_cast<DynamicRecurrentOp*>(op.get())->rnn);
InitCacheManually(); InitCacheManually();
InitStepNet(); InitStepNet();
......
...@@ -45,14 +45,14 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -45,14 +45,14 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of GaussianRandomOp should not be null."); "Output(Out) of GaussianRandomOp should not be null.");
auto dims = ctx->Attrs().Get<std::vector<int>>("dims"); auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
std::vector<int64_t> temp; std::vector<int64_t> temp;
temp.reserve(dims.size()); temp.reserve(shape.size());
for (auto dim : dims) { for (auto dim : shape) {
temp.push_back(static_cast<int64_t>(dim)); temp.push_back(static_cast<int64_t>(dim));
} }
PADDLE_ENFORCE(dims.size() > 0UL, PADDLE_ENFORCE(shape.size() > 0UL,
"dims can be one int or array. dims must be set."); "shape can be one int or array. shape must be set.");
ctx->SetOutputDim("Out", framework::make_ddim(temp)); ctx->SetOutputDim("Out", framework::make_ddim(temp));
} }
...@@ -74,7 +74,7 @@ GaussianRandom operator. ...@@ -74,7 +74,7 @@ GaussianRandom operator.
Use to initialize tensor with gaussian random generator. Use to initialize tensor with gaussian random generator.
)DOC"); )DOC");
AddAttr<std::vector<int>>("dims", "The dimension of random tensor."); AddAttr<std::vector<int>>("shape", "The dimension of random tensor.");
AddAttr<float>("mean", "mean of random tensor.").SetDefault(.0f); AddAttr<float>("mean", "mean of random tensor.").SetDefault(.0f);
AddAttr<float>("std", "std of random tensor.").SetDefault(1.0f); AddAttr<float>("std", "std of random tensor.").SetDefault(1.0f);
AddAttr<int>("seed", AddAttr<int>("seed",
......
...@@ -43,7 +43,7 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -43,7 +43,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
protected: protected:
framework::DataType IndicateDataType( framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("W")->type()); return framework::ToDataType(ctx.Input<LoDTensor>("W")->type());
} }
}; };
...@@ -93,7 +93,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -93,7 +93,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::DataType IndicateDataType( framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("W")->type()); return framework::ToDataType(ctx.Input<LoDTensor>("W")->type());
} }
}; };
......
...@@ -61,16 +61,16 @@ template <typename T> ...@@ -61,16 +61,16 @@ template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel<T> { class LookupTableCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W"); auto* table_t = context.Input<LoDTensor>("W");
auto ids_t = context.Input<Tensor>("Ids"); auto* ids_t = context.Input<LoDTensor>("Ids");
auto output_t = context.Output<Tensor>("Out"); auto* output_t = context.Output<LoDTensor>("Out");
size_t N = table_t->dims()[0]; size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1]; size_t D = table_t->dims()[1];
size_t K = ids_t->numel(); size_t K = ids_t->numel();
auto ids = ids_t->data<int64_t>(); auto* ids = ids_t->data<int64_t>();
auto table = table_t->data<T>(); auto* table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace()); auto* output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8); dim3 threads(128, 8);
dim3 grids(8, 1); dim3 grids(8, 1);
...@@ -87,9 +87,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> { ...@@ -87,9 +87,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
bool is_sparse = context.Attr<bool>("is_sparse"); bool is_sparse = context.Attr<bool>("is_sparse");
if (is_sparse) { if (is_sparse) {
auto* ids = context.Input<Tensor>("Ids"); auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<Tensor>("W"); auto* table = context.Input<LoDTensor>("W");
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out")); auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W")); auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto* ids_data = ids->data<int64_t>(); auto* ids_data = ids->data<int64_t>();
...@@ -116,12 +116,12 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> { ...@@ -116,12 +116,12 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
auto* d_output_data = d_output->data<T>(); auto* d_output_data = d_output->data<T>();
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims()); PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data, memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
d_output->numel(), stream); d_output->numel() * sizeof(T), stream);
} else { } else {
auto ids_t = context.Input<Tensor>("Ids"); auto ids_t = context.Input<LoDTensor>("Ids");
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out")); auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W")); auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));
int N = d_table_t->dims()[0]; int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1]; int D = d_table_t->dims()[1];
......
...@@ -19,22 +19,22 @@ ...@@ -19,22 +19,22 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = framework::SelectedRows;
template <typename T> template <typename T>
class LookupTableKernel : public framework::OpKernel<T> { class LookupTableKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W"); // float tensor auto* table_t = context.Input<LoDTensor>("W"); // float tensor
auto ids_t = context.Input<Tensor>("Ids"); // int tensor auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto output_t = context.Output<Tensor>("Out"); // float tensor auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
int N = table_t->dims()[0]; int N = table_t->dims()[0];
int D = table_t->dims()[1]; int D = table_t->dims()[1];
auto ids = ids_t->data<int64_t>(); auto* ids = ids_t->data<int64_t>();
auto table = table_t->data<T>(); auto* table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace()); auto* output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_t->numel(); ++i) { for (int64_t i = 0; i < ids_t->numel(); ++i) {
PADDLE_ENFORCE_LT(ids[i], N); PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0); PADDLE_ENFORCE_GE(ids[i], 0);
...@@ -49,9 +49,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -49,9 +49,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
bool is_sparse = context.Attr<bool>("is_sparse"); bool is_sparse = context.Attr<bool>("is_sparse");
if (is_sparse) { if (is_sparse) {
auto* ids = context.Input<Tensor>("Ids"); auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<Tensor>("W"); auto* table = context.Input<LoDTensor>("W");
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out")); auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W")); auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto* ids_data = ids->data<int64_t>(); auto* ids_data = ids->data<int64_t>();
...@@ -76,10 +76,10 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -76,10 +76,10 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims()); PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} else { } else {
auto* ids = context.Input<Tensor>("Ids"); auto* ids = context.Input<LoDTensor>("Ids");
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out")); auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<Tensor>(framework::GradVarName("W")); auto* d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
auto* table = context.Input<Tensor>("W"); auto* table = context.Input<LoDTensor>("W");
auto* ids_data = ids->data<int64_t>(); auto* ids_data = ids->data<int64_t>();
auto ids_dim = ids->dims(); auto ids_dim = ids->dims();
......
...@@ -89,7 +89,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel { ...@@ -89,7 +89,7 @@ class SequenceConvGradOp : public framework::OperatorWithKernel {
} }
if (ctx->HasOutput(framework::GradVarName("X"))) { if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD(framework::GradVarName("X"), "X"); ctx->ShareLoD("X", framework::GradVarName("X"));
} }
if (ctx->HasOutput(framework::GradVarName("Filter"))) { if (ctx->HasOutput(framework::GradVarName("Filter"))) {
ctx->SetOutputDim(framework::GradVarName("Filter"), ctx->SetOutputDim(framework::GradVarName("Filter"),
......
...@@ -129,7 +129,8 @@ void BindProgramDesc(py::module &m) { ...@@ -129,7 +129,8 @@ void BindProgramDesc(py::module &m) {
} }
return retv; return retv;
}) })
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference) .def("block", &ProgramDescBind::MutableBlock,
py::return_value_policy::reference)
.def("num_blocks", &ProgramDescBind::Size) .def("num_blocks", &ProgramDescBind::Size)
.def("serialize_to_string", .def("serialize_to_string",
[](ProgramDescBind &program_desc) -> py::bytes { [](ProgramDescBind &program_desc) -> py::bytes {
......
...@@ -275,7 +275,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -275,7 +275,7 @@ All parameter, weight, gradient are variables in Paddle.
const std::vector<std::array<size_t, 2>> &targets) { const std::vector<std::array<size_t, 2>> &targets) {
ProgramDescBind prog_with_targets(origin); ProgramDescBind prog_with_targets(origin);
for (const auto &t : targets) { for (const auto &t : targets) {
prog_with_targets.Block(t[0])->Op(t[1])->MarkAsTarget(); prog_with_targets.MutableBlock(t[0])->Op(t[1])->MarkAsTarget();
} }
ProgramDesc pruned_desc; ProgramDesc pruned_desc;
Prune(*prog_with_targets.Proto(), &pruned_desc); Prune(*prog_with_targets.Proto(), &pruned_desc);
...@@ -335,7 +335,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -335,7 +335,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(), PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
return OpRegistry::CreateOp(desc, nullptr); return OpRegistry::CreateOp(desc);
}) })
.def("backward", .def("backward",
[](const OperatorBase &forwardOp, [](const OperatorBase &forwardOp,
...@@ -439,7 +439,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -439,7 +439,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(), PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc, nullptr); auto rnn_op = OpRegistry::CreateOp(desc);
return static_cast<operators::RecurrentOp *>(rnn_op.release()); return static_cast<operators::RecurrentOp *>(rnn_op.release());
}) })
.def("set_stepnet", [](operators::RecurrentOp &self, .def("set_stepnet", [](operators::RecurrentOp &self,
...@@ -457,7 +457,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -457,7 +457,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(), PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc, nullptr); auto rnn_op = OpRegistry::CreateOp(desc);
return static_cast<operators::DynamicRecurrentOp *>( return static_cast<operators::DynamicRecurrentOp *>(
rnn_op.release()); rnn_op.release());
}) })
...@@ -484,7 +484,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -484,7 +484,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(), PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
auto cond_op = OpRegistry::CreateOp(desc, nullptr); auto cond_op = OpRegistry::CreateOp(desc);
return static_cast<operators::CondOp *>(cond_op.release()); return static_cast<operators::CondOp *>(cond_op.release());
}) })
.def("set_truenet", .def("set_truenet",
...@@ -498,10 +498,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -498,10 +498,7 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<framework::Executor>(m, "Executor") py::class_<framework::Executor>(m, "Executor")
.def(py::init<std::vector<platform::Place> &>()) .def(py::init<std::vector<platform::Place> &>())
.def("run", [](Executor &self, ProgramDescBind *program_bind, .def("run", &Executor::Run);
Scope *scope, int block_id) {
self.Run(*program_bind->Proto(), scope, block_id);
});
m.def("unique_integer", UniqueIntegerGenerator); m.def("unique_integer", UniqueIntegerGenerator);
m.def("init_gflags", InitGflags); m.def("init_gflags", InitGflags);
......
...@@ -62,7 +62,7 @@ class ConstantInitializer(Initializer): ...@@ -62,7 +62,7 @@ class ConstantInitializer(Initializer):
class UniformInitializer(Initializer): class UniformInitializer(Initializer):
"""Implements for random uniform distribution initializer """Implements the random uniform distribution initializer
""" """
def __init__(self, low=-1.0, high=1.0, seed=0): def __init__(self, low=-1.0, high=1.0, seed=0):
...@@ -75,6 +75,7 @@ class UniformInitializer(Initializer): ...@@ -75,6 +75,7 @@ class UniformInitializer(Initializer):
""" """
assert low is not None assert low is not None
assert high is not None assert high is not None
assert high >= low
assert seed is not None assert seed is not None
super(UniformInitializer, self).__init__() super(UniformInitializer, self).__init__()
self._low = low self._low = low
...@@ -107,3 +108,51 @@ class UniformInitializer(Initializer): ...@@ -107,3 +108,51 @@ class UniformInitializer(Initializer):
}) })
var.op = op var.op = op
return op return op
class NormalInitializer(Initializer):
"""Implements the random Normal(Gaussian) distribution initializer
"""
def __init__(self, loc=0.0, scale=1.0, seed=0):
"""Constructor for NormalInitializer
Args:
loc: mean of the normal distribution
scale: standard deviation of the normal distribution
seed: random seed
"""
assert loc is not None
assert scale is not None
assert seed is not None
super(NormalInitializer, self).__init__()
self._mean = loc
self._std_dev = scale
self._seed = seed
def __call__(self, var, block):
"""Add normal distribution initialization ops for a variable
Args:
var: Variable that needs to be initialized
block: The block in which initialization ops
should be added
Returns:
the initialization op
"""
assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block)
# Initialization Ops should be prepended and not appended
op = block.prepend_op(
type="gaussian_random",
outputs={"Out": var},
attrs={
"shape": var.shape,
"data_type": int(var.data_type),
"mean": self._mean,
"std": self._std_dev,
"seed": self._seed
})
var.op = op
return op
...@@ -19,7 +19,7 @@ class TestGaussianRandomOp(unittest.TestCase): ...@@ -19,7 +19,7 @@ class TestGaussianRandomOp(unittest.TestCase):
op = Operator( op = Operator(
"gaussian_random", "gaussian_random",
Out='Out', Out='Out',
dims=[1000, 784], shape=[1000, 784],
mean=.0, mean=.0,
std=1., std=1.,
seed=10) seed=10)
......
import unittest
import paddle.v2.framework.framework as framework
import paddle.v2.framework.initializer as initializer
DELTA = 0.00001
class TestConstantInitializer(unittest.TestCase):
def test_constant_initializer_default_value(self):
"""Test the constant initializer with default value
"""
program = framework.Program()
block = program.global_block()
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.ConstantInitializer())
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'fill_constant')
self.assertAlmostEqual(init_op.attr('value'), 0.0, delta=DELTA)
def test_constant_initializer(self):
"""Test constant initializer with supplied value
"""
program = framework.Program()
block = program.global_block()
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.ConstantInitializer(2.3))
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'fill_constant')
self.assertAlmostEqual(init_op.attr('value'), 2.3, delta=DELTA)
class TestUniformInitializer(unittest.TestCase):
def test_uniform_initializer_default_value(self):
"""Test the uniform initializer with default value
"""
program = framework.Program()
block = program.global_block()
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.UniformInitializer())
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'uniform_random')
self.assertAlmostEqual(init_op.attr('min'), -1.0, delta=DELTA)
self.assertAlmostEqual(init_op.attr('max'), 1.0, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
def test_uniform_initializer(self):
"""Test uniform initializer with supplied attributes
"""
program = framework.Program()
block = program.global_block()
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.UniformInitializer(-4.2, 3.1, 123))
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'uniform_random')
self.assertAlmostEqual(init_op.attr('min'), -4.2, delta=DELTA)
self.assertAlmostEqual(init_op.attr('max'), 3.1, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 123)
class TestNormalInitializer(unittest.TestCase):
def test_normal_initializer_default_value(self):
"""Test the normal initializer with default value
"""
program = framework.Program()
block = program.global_block()
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.NormalInitializer())
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'gaussian_random')
self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA)
self.assertAlmostEqual(init_op.attr('std'), 1.0, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)
def test_normal_initializer(self):
"""Test normal initializer with supplied attributes
"""
program = framework.Program()
block = program.global_block()
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.NormalInitializer(2.3, 1.9, 123))
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'gaussian_random')
self.assertAlmostEqual(init_op.attr('mean'), 2.3, delta=DELTA)
self.assertAlmostEqual(init_op.attr('std'), 1.9, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 123)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册