提交 e71b836f 编写于 作者: F fengjiayi

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dev_opdesc_in_python

......@@ -24,6 +24,10 @@ if(WITH_DOUBLE)
add_definitions(-DPADDLE_TYPE_DOUBLE)
endif(WITH_DOUBLE)
if(WITH_TESTING)
add_definitions(-DPADDLE_WITH_TESTING)
endif(WITH_TESTING)
if(NOT WITH_TIMER)
add_definitions(-DPADDLE_DISABLE_TIMER)
endif(NOT WITH_TIMER)
......
......@@ -22,7 +22,7 @@ Whenever we create a block, we need to set its parent block to the current block
```python
class Program(objects):
def __init__(self):
self.proto = core.NewProgram() # a C++ ProgramDesc pointer.
self.desc = core.NewProgram() # a C++ ProgramDesc pointer.
self.blocks = vector<Block>()
self.blocks.append(Block(self, -1)) # the global block
self.current_block = 0 # initialized to the global block
......@@ -57,7 +57,7 @@ A [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.m
```python
class Block(objects):
def __init__(self, program, parent_idx):
self.proto = core.NewBlock(program.proto)
self.desc = core.NewBlock(program.desc)
self.program = program
self.vars = map<string, Variable>()
self.ops = vector<Operator>()
......@@ -98,11 +98,11 @@ class Operator(object):
outputs,# dict<stirng, Variable>
attrs # dict<string, Any>
):
self.proto = core.NewOpDesc(block.proto, type, inputs, outputs, attrs)
core.infer_shape(self.proto, inputs, outputs)
self.desc = core.NewOpDesc(block.desc, type, inputs, outputs, attrs)
core.infer_shape(self.desc, inputs, outputs)
def type(self):
return self.proto.type()
return self.desc.type()
```
`Operator` creates the `OpDesc` message in C++ space, so that it can call the `InferShape` function, which is in C++.
......@@ -124,7 +124,7 @@ class Variable(object):
name = unique_name_generator()
self.name = name
self.block = block
self.proto = core.NewVarDesc(block.proto, name, shape, lod_level)
self.desc = core.NewVarDesc(block.desc, name, shape, lod_level)
self.writer = None
```
......
......@@ -19,7 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library(framework_proto SRCS framework.proto)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
......
......@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_desc.h"
#include <functional>
#include <unordered_map>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
......@@ -190,5 +193,38 @@ void OpDescBind::Sync() {
need_update_ = false;
}
}
using InferShapeFuncMap =
std::unordered_map<std::string /*op_type*/,
std::function<void(InferShapeContext *)>>;
static InferShapeFuncMap &InferShapeFuncs() {
static InferShapeFuncMap *g_map = nullptr;
if (g_map == nullptr) {
g_map = new InferShapeFuncMap();
auto &info_map = OpInfoMap::Instance();
// all registered kernels
for (auto &pair : OperatorWithKernel::AllOpKernels()) {
auto &info = info_map.Get(pair.first);
// use empty type here to avoid runtime checks.
auto op =
static_cast<OperatorWithKernel *>(info.Creator()("", {}, {}, {}));
g_map->insert(
{pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }});
}
}
return *g_map;
}
void OpDescBind::InferShape(const BlockDescBind &block) const {
auto &funcs = InferShapeFuncs();
auto it = funcs.find(this->Type());
if (it == funcs.end()) {
PADDLE_THROW("Operator %s has not been registered", this->Type());
}
CompileTimeInferShapeContext ctx(*this, block);
it->second(&ctx);
}
} // namespace framework
} // namespace paddle
......@@ -100,6 +100,8 @@ class OpDescBind {
return &this->attrs_;
}
void InferShape(const BlockDescBind &block) const;
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
......@@ -142,9 +142,9 @@ class OperatorBase {
// Macro for define a clone method.
// If you are writing an kernel operator, `Clone` will be defined when you
// register it. i.e. `Clone` method is not needed to define by yourself.
#define DEFINE_OP_CLONE_METHOD(cls) \
std::unique_ptr<OperatorBase> Clone() const final { \
return std::unique_ptr<OperatorBase>(new cls(*this)); \
#define DEFINE_OP_CLONE_METHOD(cls) \
std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final { \
return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
}
// Macro for define a default constructor for Operator.
......
......@@ -87,12 +87,12 @@ class TensorArray {
LoDTensor Stack() const;
/*
* Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors.
* Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors.
*/
void Unstack(const LoDTensor &source) const;
/*
* Unpacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors,
* Unstacks the given division of a rank-`R` tensor into rank-`(R-1)` tensors,
* with memory of tensors shared.
*/
void UnstackShared(const LoDTensor &source) const;
......
......@@ -32,5 +32,13 @@ std::vector<int64_t> VarDescBind::Shape() const {
DataType VarDescBind::GetDataType() const {
return desc_.lod_tensor().data_type();
}
void VarDescBind::SetLoDLevel(int32_t lod_level) {
desc_.mutable_lod_tensor()->set_lod_level(lod_level);
}
int32_t VarDescBind::GetLodLevel() const {
return desc_.lod_tensor().lod_level();
}
} // namespace framework
} // namespace paddle
......@@ -66,6 +66,10 @@ class VarDescBind {
DataType GetDataType() const;
void SetLoDLevel(int32_t lod_level);
int32_t GetLodLevel() const;
private:
VarDesc desc_;
};
......
......@@ -133,3 +133,4 @@ cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve .
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/dynamic_recurrent_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Scope;
using framework::TensorArray;
using framework::LoDTensor;
using framework::Variable;
namespace detail {
inline void CreateVariables(Scope& scope,
const std::vector<std::string>& var_names) {
for (const auto& name : var_names) {
scope.NewVar(name);
}
}
} // namespace detail
class DynamicRecurrentOpProtoAndCheckerMaker
: public framework::OpProtoAndCheckerMaker {
public:
DynamicRecurrentOpProtoAndCheckerMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
const auto& name = DynamicRecurrentOp::kArgName;
// inputs and outputs stored in proto
AddInput(name.inlinks,
"the inputs that need to be segmented for each step.")
.AsDuplicable();
AddInput(name.boot_memories, "variables to initialize memories.")
.AsDuplicable();
AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
.AsDuplicable();
AddOutput(name.step_scopes, "step scopes");
// Attributes stored in AttributeMap
AddAttr<std::vector<std::string>>(name.pre_memories,
"names of pre-memories");
AddAttr<std::vector<std::string>>(name.memories, "names of memories");
AddComment("This is a RNN operator for varience-length sequences.");
}
};
void DynamicRecurrentOp::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
cache_.Init(kArgName, *this, scope, &arg_);
SplitInputs();
CreateScopes();
WriteStepInputs();
InitStates();
// call stepnet in all the time steps
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& step_scope = cache_.GetScope(step);
stepnet_->Run(step_scope, dev_ctx);
}
WriteStepOutputs();
ConcatOutputs();
}
void DynamicRecurrentOp::SplitInputs() const {
// TODO(superjom) make level a config
// TODO(superjom) check all the inputs has the same LoD
int level = 0;
const auto& inlinks = cache_.inlinks;
for (const auto& item : inlinks) {
const auto& var = item.second;
const auto& tensor = var->Get<LoDTensor>();
TensorArray& ta = step_inputs_[item.first];
dy_seq_metas_[item.first] =
ta.Unpack(tensor, level, true /*length_descend*/);
if (cache_.num_steps) {
PADDLE_ENFORCE_EQ(ta.size(), cache_.num_steps,
"inputs should have the same steps");
} else {
cache_.num_steps = ta.size();
}
}
}
void DynamicRecurrentOp::WriteStepInputs() const {
for (const auto& item : cache_.inlinks) {
auto ta_it = step_inputs_.find(item.first);
PADDLE_ENFORCE(ta_it != step_inputs_.end(),
"step_inputs_ not compatible with memory set");
TensorArray& ta = ta_it->second;
for (size_t step = 0; step < ta.size(); step++) {
auto tensor = ta.Read(step);
auto& step_scope = cache_.GetScope(step);
Variable* var = step_scope.FindVar(item.first);
if (var == nullptr) {
var = step_scope.NewVar(item.first);
}
var->GetMutable<LoDTensor>()->ShareDataWith<value_type>(tensor);
}
}
}
void DynamicRecurrentOp::WriteStepOutputs() const {
for (size_t step = 0; step < cache_.scopes->size(); step++) {
auto& scope = cache_.GetScope(step);
for (auto& item : step_outputs_) {
auto* var = scope.FindVar(item.first);
if (var == nullptr) {
var = scope.NewVar(item.first);
}
auto* tensor = var->GetMutable<LoDTensor>();
item.second.WriteShared(step, *tensor);
}
}
}
void DynamicRecurrentOp::CreateScopes() const {
PADDLE_ENFORCE_GT(cache_.num_steps, 0);
// resize scopes
size_t num_scopes_need_create = cache_.num_steps - cache_.scopes->size();
for (size_t i = 0; i < num_scopes_need_create; i++) {
cache_.scopes->emplace_back(&cache_.scope->NewScope());
}
// init temporary inputs
PADDLE_ENFORCE_NOT_NULL(stepnet_, "stepnet should be set first");
std::vector<std::string> memories;
std::vector<std::string> pre_memories;
std::transform(arg_.memories.begin(), arg_.memories.end(),
std::back_inserter(memories),
[](const rnn::MemoryAttr& m) { return m.var; });
std::transform(arg_.memories.begin(), arg_.memories.end(),
std::back_inserter(pre_memories),
[](const rnn::MemoryAttr& m) { return m.pre_var; });
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& scope = cache_.GetScope(step);
detail::CreateVariables(scope, arg_.inlinks);
detail::CreateVariables(scope, arg_.outlinks);
detail::CreateVariables(scope, memories);
detail::CreateVariables(scope, pre_memories);
}
}
void DynamicRecurrentOp::ConcatOutputs() const {
// TODO(superjom) transform this to a config
int level = 0;
// TODO(superjom) pass in some lod
// just a placeholder
framework::LoD lod;
for (auto& item : step_outputs_) {
auto tensor = item.second.Pack(level, dy_seq_metas_[item.first], lod);
auto& output = cache_.outlinks[item.first]->Get<LoDTensor>();
const_cast<LoDTensor*>(&output)->ShareDataWith<value_type>(tensor);
}
}
void DynamicRecurrentOp::InitStates() const {
// init the first state
// TODO(superjom) parepare the scenerio that boot state not exists
for (auto memory : arg_.memories) {
auto* boot_state_var = cache_.scope->FindVar(memory.boot_var);
PADDLE_ENFORCE_NOT_NULL(boot_state_var);
auto& boot_state = boot_state_var->Get<LoDTensor>();
const auto& dims = boot_state.dims();
for (size_t step = 0; step < cache_.num_steps; step++) {
auto& cur_scope = cache_.GetScope(step);
// link pre-state to boot_state
// init state and pre-state
auto* pre_state = cur_scope.FindVar(memory.pre_var);
PADDLE_ENFORCE_NOT_NULL(pre_state);
pre_state->GetMutable<LoDTensor>();
auto* state = cur_scope.FindVar(memory.var);
PADDLE_ENFORCE_NOT_NULL(state);
state->GetMutable<LoDTensor>()->Resize(dims);
state->GetMutable<LoDTensor>()->mutable_data<value_type>(
platform::CPUPlace());
if (step == 0) {
auto* pre_state_tensor = pre_state->GetMutable<LoDTensor>();
pre_state_tensor->Resize(boot_state.dims());
pre_state_tensor->ShareDataWith<value_type>(boot_state);
} else {
auto& pre_scope = cache_.GetScope(step - 1);
auto* state_pre = pre_scope.FindVar(memory.var);
PADDLE_ENFORCE_NOT_NULL(state_pre);
pre_state->GetMutable<LoDTensor>()->ShareDataWith<value_type>(
*state_pre->GetMutable<LoDTensor>());
}
}
}
}
void DynamicRecurrentOp::ArgCache::Init(
const rnn::ArgumentName& name, const paddle::framework::OperatorBase& op,
const paddle::framework::Scope& scope, rnn::Argument* arg) {
this->scope = &scope;
InitArgument(name, op, arg);
CacheScopes(scope, *arg);
CacheInlinks(scope, arg->inlinks);
CacheOutlinks(scope, arg->outlinks);
}
void DynamicRecurrentOp::ArgCache::InitArgument(const rnn::ArgumentName& name,
const OperatorBase& op,
rnn::Argument* arg) {
rnn::InitArgument(name, arg, op, false /*is_grad*/);
}
void DynamicRecurrentOp::ArgCache::CacheScopes(const Scope& scope,
const rnn::Argument& arg) {
auto scopes_var = scope.FindVar(arg.step_scopes);
PADDLE_ENFORCE(scopes_var != nullptr,
"the step_scopes output argument [%s] should be created first "
"by framework.",
arg.step_scopes);
this->scopes = scopes_var->GetMutable<std::vector<Scope*>>();
}
void DynamicRecurrentOp::ArgCache::CacheInlinks(
const Scope& scope, const std::vector<std::string>& names) {
for (auto name : names) {
auto* var = GetVariable(scope, name);
inlinks[name] = var;
}
}
void DynamicRecurrentOp::ArgCache::CacheOutlinks(
const Scope& scope, const std::vector<std::string>& names) {
for (auto name : names) {
auto* var = GetVariable(scope, name);
outlinks[name] = var;
}
}
Variable* DynamicRecurrentOp::ArgCache::GetVariable(const Scope& scope,
const std::string& name) {
auto* var = scope.FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var, "variable [%s] not exist in scope", name);
return var;
}
const rnn::ArgumentName DynamicRecurrentOp::kArgName{
"step_net", "step_scopes", "inlinks", "outlinks",
"memories", "pre_memories", "boot_memories"};
void DynamicRecurrentGradientOp::Run(
const Scope& scope, const platform::DeviceContext& dev_ctx) const {}
} // namespace operators
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
dynamic_recurrent, paddle::operators::DynamicRecurrentOp,
paddle::operators::DynamicRecurrentOpProtoAndCheckerMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_TESTING
#include "gtest/gtest.h"
#endif
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor_array.h"
#include "paddle/framework/variable.h"
#include "paddle/operators/rnn/recurrent_op_utils.h"
namespace paddle {
namespace operators {
class DynamicRecurrentOp : public framework::OperatorBase {
public:
static const rnn::ArgumentName kArgName;
using value_type = float;
DynamicRecurrentOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
DynamicRecurrentOp(const DynamicRecurrentOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
// TODO(yuyang18): Implement copy ctor well.
PADDLE_THROW("Not implemented");
}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override;
/*
* Split the inputs(LoDTensors) to segments for each time step.
*/
void SplitInputs() const;
/*
* Create step-scopes to store temporary outputs in each time steps.
*/
void CreateScopes() const;
/*
* Link TensorArray steps to the corresponding variables located in
* step-scopes.
*/
void WriteStepInputs() const;
/*
* Write output of each step to the corresponding TensorArray.
*/
void WriteStepOutputs() const;
/*
* Initialize the states, each state will have a corresponding pre-state,
* which share the memory with the state in the previous time state. The
* pre-state in the first time step will be initialized with an zero tensor or
* a tensor in parent scope if is provided.
*/
void InitStates() const;
/*
* Concatenate outputs in each time step and generate a LoDTensor.
*/
void ConcatOutputs() const;
/*
* set a stepnet that is created according to a RecurrentOp's stepnet.
*/
void SetStepNet(std::unique_ptr<OperatorBase> net) {
PADDLE_ENFORCE_NOT_NULL(net);
stepnet_ = std::move(net);
}
const OperatorBase& GetStepNet() const { return *stepnet_; }
protected:
struct ArgCache {
framework::Scope const* scope;
std::vector<framework::Scope*>* scopes;
std::map<std::string, framework::Variable*> inlinks;
std::map<std::string, framework::Variable*> outlinks;
size_t num_steps{0};
void Init(const rnn::ArgumentName& name, const OperatorBase& op,
const framework::Scope& scope, rnn::Argument* arg);
framework::Scope& GetScope(size_t index) {
PADDLE_ENFORCE_LT(index, num_steps);
return *scopes->at(index);
}
private:
void InitArgument(const rnn::ArgumentName& name, const OperatorBase& op,
rnn::Argument* arg);
void CacheScopes(const framework::Scope& scope, const rnn::Argument& arg);
void CacheInlinks(const framework::Scope& scope,
const std::vector<std::string>& names);
void CacheOutlinks(const framework::Scope& scope,
const std::vector<std::string>& names);
framework::Variable* GetVariable(const framework::Scope& scope,
const std::string& name);
};
private:
std::unique_ptr<OperatorBase> stepnet_;
mutable framework::TensorArray states_;
mutable std::map<std::string, framework::TensorArray> step_inputs_;
mutable std::map<std::string, framework::TensorArray> step_outputs_;
mutable std::map<std::string, std::vector<framework::DySeqMeta>>
dy_seq_metas_;
mutable rnn::Argument arg_;
mutable ArgCache cache_;
#ifdef PADDLE_WITH_TESTING
friend class DynamicRecurrentOpTestHelper;
FRIEND_TEST(DynamicRecurrentOpTestHelper, SplitInputs);
FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateCache);
FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateScopes);
FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepInputs);
FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepOutputs);
FRIEND_TEST(DynamicRecurrentOpTestHelper, InitStates);
FRIEND_TEST(DynamicRecurrentOpTestHelper, ConcatOutputs);
#endif
};
class DynamicRecurrentGradientOp : public framework::OperatorBase {
public:
DynamicRecurrentGradientOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override;
};
} // namespace operators
} // namespace paddle
#include "paddle/operators/dynamic_recurrent_op.h"
#include <gtest/gtest.h>
#include "paddle/framework/ddim.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_desc.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
using framework::Scope;
using framework::TensorArray;
using framework::LoDTensor;
using framework::Variable;
class TestOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
DEFINE_OP_CLONE_METHOD(TestOp);
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
};
void OpDescNewVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::OpDesc::Var* var) {
var->set_parameter(param_name);
for (auto& arg_name : arguments) {
var->add_arguments(arg_name);
}
}
// create a LoD tensor in scope with specific dims
LoDTensor* CreateVar(Scope& scope, std::string name, framework::DDim dims,
const platform::Place& place) {
auto* var = scope.NewVar(name);
auto* tensor = var->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(place);
return tensor;
}
class DynamicRecurrentOpTestHelper : public ::testing::Test {
protected:
const rnn::ArgumentName argname = DynamicRecurrentOp::kArgName;
virtual void SetUp() override {
CreateGlobalVariables();
auto op_desc = CreateOpDesc();
op = paddle::framework::OpRegistry::CreateOp(op_desc);
dop = dynamic_cast<DynamicRecurrentOp*>(op.get());
InitCacheManually();
InitStepNet();
}
framework::OpDesc CreateOpDesc() {
// create op
paddle::framework::OpDesc op_desc;
op_desc.set_type("dynamic_recurrent");
OpDescNewVar(argname.inlinks, {"in0"}, op_desc.add_inputs());
OpDescNewVar(argname.boot_memories, {"boot_mem"}, op_desc.add_inputs());
OpDescNewVar(argname.step_scopes, {"step_scopes"}, op_desc.add_outputs());
OpDescNewVar(argname.outlinks, {"out0"}, op_desc.add_outputs());
// set pre-memories
auto pre_memories = op_desc.mutable_attrs()->Add();
pre_memories->set_name(argname.pre_memories);
pre_memories->set_type(paddle::framework::AttrType::STRINGS);
auto pre_memories_item = pre_memories->add_strings();
*pre_memories_item = "mem@pre";
// set memories
auto memories = op_desc.mutable_attrs()->Add();
memories->set_name(argname.memories);
memories->set_type(paddle::framework::AttrType::STRINGS);
auto memories_item = memories->add_strings();
*memories_item = "mem";
return op_desc;
}
void CreateGlobalVariables() {
platform::CPUPlace place;
scope.NewVar("step_scopes");
CreateVar(scope, "boot_mem", framework::make_ddim({10, 20}), place);
// auto* out0 =
CreateVar(scope, "out0", framework::make_ddim({10, 20}), place);
auto* in0 = CreateVar(scope, "in0", framework::make_ddim({10, 8}), place);
// 10 instanes with 4 sentences, length is 4, 3, 2, 1 respectively.
framework::LoD in0_lod(1);
for (int x : std::vector<int>{0, 4, 7, 9, 10}) {
in0_lod[0].push_back(x);
}
in0->set_lod(in0_lod);
in0->Resize(framework::make_ddim({10, 8}));
// set the content, each sentence content is seqid.batchid
// the seqid starts from 0
int start = 0;
for (size_t seqid = 0; seqid < in0_lod.size() - 1; seqid++) {
for (size_t batchid = 0;
batchid < in0_lod[0][seqid + 1] - in0_lod[0][seqid]; batchid++) {
float v = seqid + batchid * 0.1;
for (size_t dim = 0; dim < 8; dim++) {
in0->data<float>()[start * 8 + dim] = v;
}
start++;
}
}
}
void InitCacheManually() {
dop->cache_.Init(DynamicRecurrentOp::kArgName, *dop, scope, &dop->arg_);
}
void InitStepNet() {
std::unique_ptr<framework::OperatorBase> stepnet{new NetOp};
dynamic_cast<NetOp*>(stepnet.get())
->AppendOp(std::unique_ptr<TestOp>(new TestOp(
"test", {{"inlinks", {"in0"}}, {"boot_memories", {"boot_mem"}}},
{{"outlinks", {"out0"}}, {"step_scopes", {"step_scopes"}}}, {})));
dop->SetStepNet(std::move(stepnet));
}
protected:
DynamicRecurrentOp* dop;
std::unique_ptr<framework::OperatorBase> op;
paddle::platform::CPUDeviceContext device_context;
paddle::framework::Scope scope;
};
TEST_F(DynamicRecurrentOpTestHelper, CreateCache) {
const rnn::Argument& arg = dop->arg_;
ASSERT_EQ(arg.inlinks.size(), 1UL);
ASSERT_EQ(arg.outlinks.size(), 1UL);
}
TEST_F(DynamicRecurrentOpTestHelper, SplitInputs) {
dop->SplitInputs();
auto& in0_ta = dop->step_inputs_["in0"];
ASSERT_EQ(in0_ta.size(), 4UL);
const auto& batch0 = in0_ta.Read(0);
const auto& batch1 = in0_ta.Read(1);
const auto& batch2 = in0_ta.Read(2);
const auto& batch3 = in0_ta.Read(3);
EXPECT_EQ(batch0.dims()[0], 4);
EXPECT_EQ(batch1.dims()[0], 3);
EXPECT_EQ(batch2.dims()[0], 2);
EXPECT_EQ(batch3.dims()[0], 1);
}
TEST_F(DynamicRecurrentOpTestHelper, CreateScopes) {
dop->SplitInputs();
dop->CreateScopes();
ASSERT_EQ(dop->cache_.num_steps, 4UL);
ASSERT_EQ(dop->cache_.scopes->size(), 4UL);
}
TEST_F(DynamicRecurrentOpTestHelper, WriteStepInputs) {
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
for (size_t step = 0; step < dop->cache_.num_steps; step++) {
auto& scope = dop->cache_.GetScope(step);
for (auto name : std::vector<std::string>({"in0"})) {
ASSERT_TRUE(scope.FindVar(name) != nullptr);
}
}
}
TEST_F(DynamicRecurrentOpTestHelper, WriteStepOutputs) {
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
dop->WriteStepOutputs();
for (size_t step = 0; step < dop->cache_.num_steps; step++) {
auto& scope = dop->cache_.GetScope(step);
for (auto name : std::vector<std::string>({"out0"})) {
ASSERT_TRUE(scope.FindVar(name));
}
}
}
TEST_F(DynamicRecurrentOpTestHelper, ConcatOutputs) {
// Let's leave this test to python unittest.
}
TEST_F(DynamicRecurrentOpTestHelper, InitStates) {
dop->SplitInputs();
dop->CreateScopes();
dop->WriteStepInputs();
dop->WriteStepOutputs();
dop->InitStates();
for (size_t step = 0; step < dop->cache_.num_steps; step++) {
auto& scope = dop->cache_.GetScope(step);
auto state = scope.FindVar("mem");
ASSERT_TRUE(state != nullptr);
auto* pre_state = scope.FindVar("mem@pre");
ASSERT_TRUE(pre_state != nullptr);
auto* boot_state = scope.FindVar("boot_mem");
ASSERT_TRUE(boot_state != nullptr);
if (step == 0) {
// check pre_state is a reference of boot_state
ASSERT_EQ(boot_state->Get<LoDTensor>().data<float>(),
pre_state->Get<LoDTensor>().data<float>());
}
}
}
} // operators
} // namespace paddle
......@@ -166,7 +166,9 @@ void BindVarDsec(py::module &m) {
.def("set_shape", &VarDescBind::SetShape)
.def("set_data_type", &VarDescBind::SetDataType)
.def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
.def("data_type", &VarDescBind::GetDataType);
.def("data_type", &VarDescBind::GetDataType)
.def("lod_level", &VarDescBind::GetLodLevel)
.def("set_lod_level", &VarDescBind::SetLoDLevel);
}
void BindOpDesc(py::module &m) {
......@@ -196,7 +198,8 @@ void BindOpDesc(py::module &m) {
.def("set_attr", &OpDescBind::SetAttr)
.def("attr", &OpDescBind::GetAttr)
.def("set_block_attr", &OpDescBind::SetBlockAttr)
.def("block_attr", &OpDescBind::GetBlockAttr);
.def("block_attr", &OpDescBind::GetBlockAttr)
.def("infer_shape", &OpDescBind::InferShape);
}
} // namespace pybind
......
......@@ -231,21 +231,6 @@ All parameter, weight, gradient are variables in Paddle.
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
})
.def_static("infer_shape",
[](OpDescBind &op_desc, BlockDescBind &block) {
auto op = OpRegistry::CreateOp(*op_desc.Proto());
auto *op_with_kernel =
dynamic_cast<OperatorWithKernel *>(op.get());
if (op_with_kernel != nullptr) {
auto ctx = CompileTimeInferShapeContext(op_desc, block);
op_with_kernel->InferShape(&ctx);
} else {
PADDLE_THROW(
"OP(%s) is not type of OperatorWithKernel, "
"should not call this function",
op_desc.Type());
}
})
.def("backward",
[](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) {
......
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.framework_pb2 as framework_pb2
import collections
import numpy as np
import copy
__all__ = ['Block', 'Variable', 'Program', 'Operator']
......@@ -40,35 +42,104 @@ class OpProtoHolder(object):
class Variable(object):
def __init__(self, block, name=None, shape=None, dtype=None,
lod_level=None):
def __init__(self,
block,
name=None,
shape=None,
dtype=None,
lod_level=None,
**kwargs):
self.block = block
if name is None:
name = Variable._unique_var_name_()
self.desc = self.block.desc.new_var(name)
try:
self.desc = self.block.desc.var(name)
is_new_var = False
except core.EnforceNotMet:
self.desc = self.block.desc.new_var(name)
is_new_var = True
if shape is not None:
self.desc.set_shape(shape)
if is_new_var:
self.desc.set_shape(shape)
else:
old_shape = self.shape
shape = tuple(shape)
if shape != old_shape:
raise ValueError(
"Variable {0} has been created before. the previous "
"shape is {1}; the new shape is {2}. They are not "
"matched.".format(self.name, old_shape, shape))
if dtype is not None:
# TODO(yuyang18): Convert dtype from numpy.dtype
self.desc.set_data_type(dtype)
if not isinstance(dtype, core.DataType):
dtype = Variable._convert_np_dtype_to_dtype_(dtype)
if is_new_var:
self.desc.set_data_type(dtype)
else:
old_dtype = self.data_type()
if dtype != old_shape:
raise ValueError("Variable {0} has been created before. "
"The previous data type is {1}; the new "
"data type is {2}. They are not "
"matched.".format(self.name, old_dtype,
dtype))
if lod_level is not None:
# TODO(yuyang18): set_lod_level is not defined.
self.desc.set_lod_level(lod_level)
if is_new_var:
self.desc.set_lod_level(lod_level)
else:
if lod_level != self.lod_level:
raise ValueError("Variable {0} has been created before. "
"The previous lod_level is {1}; the new "
"lod_level is {2}. They are not "
"matched".format(self.name, self.lod_level,
lod_level))
self.block.vars[name] = self
self.op = None
# TODO(yuyang18): Get methods
@property
def name(self):
return self.desc.name()
@property
def shape(self):
# convert to tuple, make it as same as numpy API.
return tuple(self.desc.shape())
@property
def data_type(self):
return self.desc.data_type()
@property
def lod_level(self):
return self.desc.lod_level()
@staticmethod
def _unique_var_name_():
uid = core.unique_integer() # unique during whole process.
return "_generated_var_%d" % uid
@staticmethod
def _convert_np_dtype_to_dtype_(np_dtype):
dtype = np.dtype(np_dtype)
if dtype == np.float32:
return core.DataType.FP32
elif dtype == np.float64:
return core.DataType.FP64
elif dtype == np.float16:
return core.DataType.FP16
elif dtype == np.int32:
return core.DataType.INT32
elif dtype == np.int16:
return core.DataType.INT16
elif dtype == np.int64:
return core.DataType.INT64
elif dtype == np.bool:
return core.DataType.BOOL
else:
raise ValueError("Not supported numpy dtype " + str(dtype))
class Operator(object):
def __init__(self, block, desc, type, inputs=None, outputs=None,
......@@ -169,6 +240,10 @@ class Block(object):
def create_var(self, *args, **kwargs):
return Variable(self, *args, **kwargs)
def create_parameter(self, *args, **kwargs):
global_block = self.program.global_block()
return Parameter(global_block, *args, **kwargs)
def append_op(self, *args, **kwargs):
op_desc = self.desc.append_op()
op = Operator(self, op_desc, *args, **kwargs)
......@@ -215,5 +290,41 @@ class Program(object):
self.current_block_idx = self.current_block().parent_idx
class Parameter(Variable):
def __init__(self, block, shape, dtype, **kwargs):
if shape is None or dtype is None:
raise ValueError("Parameter must set shape and dtype")
if len(shape) == 0:
raise ValueError("Parameter shape cannot be empty")
for each in shape:
if each < 0:
raise ValueError("Parameter shape should not be related with "
"batch-size")
Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs)
self.trainable = kwargs.get('trainable', True)
self.init_attr = kwargs.get('initialize_attr', {
'type': 'uniform_random',
'min': -1.0,
'max': 1.0
})
self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
self._append_initialize_ops_()
def _append_initialize_ops_(self):
attr = copy.deepcopy(self.init_attr)
op_type = attr.pop('type', None)
block = self.block
assert isinstance(block, Block)
shape = self.shape
attr['dims'] = shape
attr['data_type'] = int(self.data_type)
op = block.prepend_op(
type=op_type, inputs=None, outputs={'Out': [self]}, attrs=attr)
self.op = op
# program is a global instance.
g_program = Program.instance()
import unittest
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
class TestInferShape(unittest.TestCase):
......@@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase):
sum_op_desc.set_input("X", ["x1", "x2"])
sum_op_desc.set_output("Out", ["out"])
core.Operator.infer_shape(sum_op_desc, block)
sum_op_desc.infer_shape(block)
self.assertEqual(out.shape(), shape)
def test_mul_op(self):
......@@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase):
mul_op_desc.set_attr("x_num_col_dims", 1)
mul_op_desc.set_attr("y_num_col_dims", 1)
core.Operator.infer_shape(mul_op_desc, block)
mul_op_desc.infer_shape(block)
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
......
import unittest
from paddle.v2.framework.graph import g_program
import paddle.v2.framework.core as core
class TestParameter(unittest.TestCase):
def test_param(self):
b = g_program.create_block()
param = b.create_parameter(
name='fc.w',
shape=[784, 100],
dtype='float32',
initialize_attr={
'type': 'uniform_random',
'seed': 13,
'min': -5.0,
'max': 5.0
})
self.assertIsNotNone(param)
self.assertEqual('fc.w', param.name)
self.assertEqual((784, 100), param.shape)
self.assertEqual(core.DataType.FP32, param.data_type)
self.assertEqual(0, param.block.idx)
if __name__ == '__main__':
unittest.main()
import unittest
from paddle.v2.framework.graph import Variable, g_program
import paddle.v2.framework.core as core
import numpy as np
class TestVariable(unittest.TestCase):
def test_np_dtype_convert(self):
DT = core.DataType
convert = Variable._convert_np_dtype_to_dtype_
self.assertEqual(DT.FP32, convert(np.float32))
self.assertEqual(DT.FP16, convert("float16"))
self.assertEqual(DT.FP64, convert("float64"))
self.assertEqual(DT.INT32, convert("int32"))
self.assertEqual(DT.INT16, convert("int16"))
self.assertEqual(DT.INT64, convert("int64"))
self.assertEqual(DT.BOOL, convert("bool"))
self.assertRaises(ValueError, lambda: convert("int8"))
def test_var(self):
b = g_program.current_block()
w = b.create_var(
dtype="float64", shape=[784, 100], lod_level=0, name="fc.w")
self.assertEqual(core.DataType.FP64, w.data_type)
self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name)
self.assertEqual(0, w.lod_level)
w = b.create_var(name='fc.w')
self.assertEqual(core.DataType.FP64, w.data_type)
self.assertEqual((784, 100), w.shape)
self.assertEqual("fc.w", w.name)
self.assertEqual(0, w.lod_level)
self.assertRaises(ValueError,
lambda: b.create_var(name="fc.w", shape=(24, 100)))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册