Commit 0d7e4294 authored by superjom

remove alias

Parent 8162e2c7
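In short: this commit drops the inlink/outlink alias machinery from RecurrentOp. The `inlink_alias`/`outlink_alias` attributes, the `rnn::Link` struct that paired an external variable name with an internal alias, and the `@alias` variable names in the Python test all go away; in/out links become plain variable names used identically in the parent scope and in every step scope.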
@@ -28,9 +28,7 @@ using Variable = framework::Variable;
 using Tensor = framework::Tensor;
 
 void RecurrentAlgorithm::InferShape(const Scope& scope) const {
-  seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
-                 ->GetMutable<Tensor>()
-                 ->dims()[0];
+  seq_len_ = scope.FindVar(arg_->inlinks[0])->GetMutable<Tensor>()->dims()[0];
   CreateScopes(scope);
   auto step_scopes = GetStepScopes(scope);
   rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
@@ -121,14 +119,12 @@ void RecurrentAlgorithm::InitMemories(Scope* step_scope,
 }
 
 const rnn::ArgumentName RecurrentOp::kArgName{
-    "step_net", "step_scopes", "inlinks",
-    "outlinks", "inlink_alias", "outlink_alias",
+    "step_net", "step_scopes", "inlinks", "outlinks",
     "memories", "pre_memories", "boot_memories"};
 
 const rnn::ArgumentName RecurrentGradientOp::kArgName{
-    "step_net", "step_scopes", "outlink@grad",
-    "inlink@grad", "inlink_alias", "outlink_alias",
-    "memories", "pre_memories", "boot_memories@grad"};
+    "step_net", "step_scopes", "outlink@grad", "inlink@grad",
+    "memories", "pre_memories", "boot_memories@grad"};
 
 RecurrentOp::RecurrentOp(const std::string& type,
                          const framework::VariableNameMap& inputs,
@@ -158,8 +154,6 @@ class RecurrentAlgorithmProtoAndCheckerMaker
     AddOutput(name.step_scopes, "step scopes");
 
     // Attributes stored in AttributeMap
-    AddAttr<std::vector<std::string>>(name.inlink_alias, "alias of inlinks");
-    AddAttr<std::vector<std::string>>(name.outlink_alias, "alias of outlinks");
     AddAttr<std::vector<std::string>>(name.pre_memories,
                                       "names of pre-memories");
     AddAttr<std::vector<std::string>>(name.memories, "names of memories");
@@ -204,9 +198,7 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
 }
 
 void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
-  seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
-                 ->GetMutable<Tensor>()
-                 ->dims()[0];
+  seq_len_ = scope.FindVar(arg_->inlinks[0])->GetMutable<Tensor>()->dims()[0];
   auto step_scopes = GetStepScopes(scope);
   rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
                      true /*infer_shape_mode*/);
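Both InferShape paths above collapse to a single chained lookup: with links stored as plain strings, the sequence length is read straight from dimension 0 of the first input tensor found in the enclosing scope. A self-contained toy sketch of that access pattern (Tensor and Scope here are stand-ins, not the real framework types):

    #include <cassert>
    #include <map>
    #include <string>
    #include <vector>

    // Toy stand-ins just to show the lookup pattern used by InferShape.
    struct Tensor {
      std::vector<long> dims;  // dims[0] is the sequence length of an RNN input
    };

    struct Scope {
      std::map<std::string, Tensor> vars;
      Tensor* FindVar(const std::string& name) {
        auto it = vars.find(name);
        return it == vars.end() ? nullptr : &it->second;
      }
    };

    int main() {
      Scope scope;
      scope.vars["x"] = Tensor{{10, 4, 32}};  // seq_len=10, batch=4, dim=32
      std::vector<std::string> inlinks = {"x"};
      // Equivalent of: seq_len_ = scope.FindVar(inlinks[0])->...->dims()[0];
      long seq_len = scope.FindVar(inlinks[0])->dims[0];
      assert(seq_len == 10);
      return 0;
    }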
@@ -23,13 +23,13 @@ namespace f = paddle::framework;
 using Tensor = framework::Tensor;
 
 void SegmentInputs(const std::vector<Scope*>& step_scopes,
-                   const std::vector<Link>& inlinks, const size_t seq_len,
-                   bool infer_shape_mode) {
+                   const std::vector<std::string>& inlinks,
+                   const size_t seq_len, bool infer_shape_mode) {
   PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
   for (size_t i = 0; i < inlinks.size(); ++i) {
-    auto input_var = step_scopes[0]->FindVar(inlinks[i].external);
+    auto input_var = step_scopes[0]->FindVar(inlinks[i]);
     PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
-                   inlinks[i].external);
+                   inlinks[i]);
     Tensor* input = input_var->GetMutable<Tensor>();
     f::DDim dims = input->dims();
@@ -38,7 +38,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
     f::DDim step_dims = slice_ddim(dims, 1, dims.size());
     for (size_t j = 0; j < seq_len; j++) {
      Tensor* step_input =
-          step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
+          step_scopes[j]->NewVar(inlinks[i])->GetMutable<Tensor>();
       if (!infer_shape_mode) {
         *step_input = input->Slice<float>(j, j + 1);
       }
@@ -48,18 +48,17 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
 }
 
 void ConcatOutputs(const std::vector<Scope*>& step_scopes,
-                   const std::vector<Link>& outlinks, const size_t seq_len,
-                   bool infer_shape_mode) {
+                   const std::vector<std::string>& outlinks,
+                   const size_t seq_len, bool infer_shape_mode) {
   for (size_t i = 0; i < outlinks.size(); i++) {
-    auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
+    auto output_var = step_scopes[0]->FindVar(outlinks[i]);
     PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
-                   outlinks[i].external);
+                   outlinks[i]);
     Tensor* output = output_var->GetMutable<Tensor>();
 
     if (infer_shape_mode) {
-      auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
-      PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
-                     outlinks[i].internal);
+      auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
+      PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope", outlinks[i]);
       f::DDim step_dims = step_scope_var->template GetMutable<Tensor>()->dims();
       std::vector<int64_t> dims_vec = vectorize(step_dims);
       dims_vec.insert(dims_vec.begin(), seq_len);
@@ -68,7 +67,7 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
     output->mutable_data<float>(platform::CPUPlace());
     for (size_t j = 0; j < seq_len; j++) {
       Tensor* step_output =
-          step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
+          step_scopes[j]->FindVar(outlinks[i])->GetMutable<Tensor>();
       // TODO(luotao02) data type and platform::DeviceContext() should set
       // correctly
       (output->Slice<float>(j, j + 1))
@@ -108,29 +107,9 @@ void InitArgument(const ArgumentName& name, Argument* arg,
                   const framework::OperatorBase& op) {
   arg->step_scopes = op.Output(name.step_scopes);
 
-  auto inlinks = op.Inputs(name.inlinks);
-  auto inlink_alias = op.Attr<std::vector<std::string>>(name.inlink_alias);
-  PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(),
-                 "the size of inlinks and inlink_alias don't match:%d,%d",
-                 inlinks.size(), inlink_alias.size());
-  for (size_t i = 0; i < inlinks.size(); ++i) {
-    rnn::Link link;
-    link.external = inlinks[i];
-    link.internal = inlink_alias[i];
-    (arg->inlinks).push_back(link);
-  }
+  arg->inlinks = op.Inputs(name.inlinks);
 
-  auto outlinks = op.Outputs(name.outlinks);
-  auto outlink_alias = op.Attr<std::vector<std::string>>(name.outlink_alias);
-  PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(),
-                 "the size of outlinks and outlink_alias don't match:%d,%d",
-                 outlinks.size(), outlink_alias.size());
-  for (size_t i = 0; i < outlinks.size(); ++i) {
-    rnn::Link link;
-    link.external = outlinks[i];
-    link.internal = outlink_alias[i];
-    (arg->outlinks).push_back(link);
-  }
+  arg->outlinks = op.Outputs(name.outlinks);
 
   arg->boot_memories = op.Inputs(name.boot_memories);
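With the aliases gone, InitArgument shrinks from two name/alias zip loops (plus their size-mismatch PADDLE_ENFORCEs) to two direct assignments, and SegmentInputs/ConcatOutputs address each link by the same name in the parent scope and in every step scope. A hedged sketch of the segmenting idea with toy map-based scopes (the real code slices Tensors along dimension 0; the types below are illustrative only):

    #include <cstdio>
    #include <map>
    #include <string>
    #include <vector>

    // Toy sketch of what SegmentInputs now does: every step scope gets a
    // variable with the SAME name as the parent-scope input, holding one
    // time slice. (We model a tensor as a vector of rows.)
    using Row = std::vector<float>;
    using Scope = std::map<std::string, std::vector<Row>>;

    void SegmentInputs(std::vector<Scope>& step_scopes, const Scope& parent,
                       const std::vector<std::string>& inlinks, size_t seq_len) {
      for (const auto& name : inlinks) {
        const auto& seq = parent.at(name);  // seq[t] is the t-th time step
        for (size_t t = 0; t < seq_len; ++t) {
          step_scopes[t][name] = {seq[t]};  // same key, no "@alias" needed
        }
      }
    }

    int main() {
      Scope parent = {{"x", {{1, 2}, {3, 4}, {5, 6}}}};  // seq_len = 3
      std::vector<Scope> steps(3);
      SegmentInputs(steps, parent, {"x"}, 3);
      std::printf("step 1, x[0][0] = %g\n", steps[1]["x"][0][0]);  // prints 3
      return 0;
    }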
@@ -41,18 +41,11 @@ struct MemoryAttr {
   std::string boot_var;
 };
 
-struct Link {
-  // input or output links name.
-  std::string internal;
-  // alias to avoid duplicate keys in scopes.
-  std::string external;
-};
-
 struct Argument {
   std::string step_net;
   std::string step_scopes;
-  std::vector<Link> inlinks;
-  std::vector<Link> outlinks;
+  std::vector<std::string> inlinks;
+  std::vector<std::string> outlinks;
   std::vector<rnn::MemoryAttr> memories;
 };
@@ -61,8 +54,6 @@ struct ArgumentName {
   std::string step_scopes;
   std::string inlinks;
   std::string outlinks;
-  std::string inlink_alias;   // the alias of inlinks in step net.
-  std::string outlink_alias;  // the alias of outlinks in step net.
   std::string memories;       // the memory name
   std::string pre_memories;   // the previous memory name
   std::string boot_memories;  // the boot memory name
@@ -72,15 +63,15 @@ struct ArgumentName {
  * Prepare inputs for each step net.
  */
 void SegmentInputs(const std::vector<Scope*>& step_scopes,
-                   const std::vector<Link>& inlinks, const size_t seq_len,
-                   bool infer_shape_mode);
+                   const std::vector<std::string>& inlinks,
+                   const size_t seq_len, bool infer_shape_mode);
 
 /**
  * Process outputs of step nets and merge to variables.
  */
 void ConcatOutputs(const std::vector<Scope*>& step_scopes,
-                   const std::vector<Link>& outlinks, const size_t seq_len,
-                   bool infer_shape_mode);
+                   const std::vector<std::string>& outlinks,
+                   const size_t seq_len, bool infer_shape_mode);
 
 void LinkMemories(const std::vector<Scope*>& step_scopes,
                   const std::vector<MemoryAttr>& memories, const size_t step_id,
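The deleted Link struct justified the external name as an "alias to avoid duplicate keys in scopes". With each time step owning its own child scope, the same key can safely appear in the parent scope and in a step scope: a child's NewVar shadows the parent's variable rather than colliding with it. A minimal sketch of that shadowing behavior (a toy scope, not framework::Scope):

    #include <cassert>
    #include <map>
    #include <string>

    // Minimal hierarchical scope: lookups fall back to the parent, and
    // NewVar in a child shadows (does not overwrite) the parent's variable
    // of the same name.
    struct Scope {
      explicit Scope(Scope* parent = nullptr) : parent_(parent) {}
      int* NewVar(const std::string& name) { return &vars_[name]; }
      int* FindVar(const std::string& name) {
        auto it = vars_.find(name);
        if (it != vars_.end()) return &it->second;
        return parent_ ? parent_->FindVar(name) : nullptr;
      }
      std::map<std::string, int> vars_;
      Scope* parent_;
    };

    int main() {
      Scope parent;
      *parent.NewVar("h") = 100;  // whole-sequence output lives in the parent
      Scope step(&parent);        // one child scope per time step
      *step.NewVar("h") = 7;      // per-step output: shadows, no clash
      assert(*step.FindVar("h") == 7);
      assert(*parent.FindVar("h") == 100);
      return 0;
    }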
@@ -123,7 +123,6 @@ class TestRecurrentOp(unittest.TestCase):
         create_tensor(self.scope, "h_boot", [self.batch_size, self.input_dim],
                       h_boot_np_data)
         self.scope.new_var("step_scopes")
-        self.scope.new_var("h@alias")
         self.scope.new_var("h")
 
     def create_rnn_op(self):
@@ -137,17 +136,15 @@ class TestRecurrentOp(unittest.TestCase):
             outlinks=["h"],
             step_scopes="step_scopes",
             # attributes
-            inlink_alias=["x@alias"],
-            outlink_alias=["h@alias"],
             pre_memories=["h@pre"],
-            memories=["h@alias"])
+            memories=["h@mem"])
 
     def create_step_net(self):
         stepnet = core.Net.create()
-        x_fc_op = Operator("mul", X="x@alias", Y="W", Out="Wx")
+        x_fc_op = Operator("mul", X="x", Y="W", Out="Wx")
         h_fc_op = Operator("mul", X="h@pre", Y="U", Out="Uh")
         sum_op = Operator("add", X="Wx", Y="Uh", Out="sum")
-        sig_op = Operator("sigmoid", X="sum", Y="h@alias")
+        sig_op = Operator("sigmoid", X="sum", Y="h@mem")
 
         for op in [x_fc_op, h_fc_op, sum_op, sig_op]:
             stepnet.append_op(op)
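On the Python side the step net now consumes the inlink "x" under its real name, while the recurrent state keeps a name of its own, renamed from "h@alias" to "h@mem" since it is a memory rather than an alias. The recurrence the test wires up is h = sigmoid(x*W + h_pre*U); a scalar sketch of that computation (the test itself uses matrix "mul"/"add"/"sigmoid" operators, so this is only an illustration of the dataflow):

    #include <cmath>
    #include <cstdio>

    // Scalar version of the step net built in create_step_net:
    //   Wx = x * W;  Uh = h_pre * U;  h = sigmoid(Wx + Uh)
    double step(double x, double h_pre, double W, double U) {
      double Wx = x * W;                      // x_fc_op ("mul")
      double Uh = h_pre * U;                  // h_fc_op ("mul")
      double sum = Wx + Uh;                   // sum_op  ("add")
      return 1.0 / (1.0 + std::exp(-sum));    // sig_op  ("sigmoid") -> h@mem
    }

    int main() {
      double h = 0.0;                         // h_boot
      const double xs[] = {0.5, -1.0, 2.0};   // one input per time step
      for (double x : xs) {
        h = step(x, h, /*W=*/0.3, /*U=*/0.7);
        std::printf("h = %f\n", h);
      }
      return 0;
    }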