提交 3c29224e 编写于 作者: S superjom

remove alias

上级 b818e647
......@@ -58,6 +58,8 @@ class Scope {
/// nullptr if cannot find.
Variable* FindVar(const std::string& name) const;
const Scope& parent() const { return *parent_; }
/// Find the scope or an ancestor scope that contains the given variable.
const Scope* FindScope(const Variable* var) const;
......
......@@ -94,7 +94,7 @@ class PReluGradKernel : public framework::OpKernel {
Transform(context.device_context(), out_ptr, out_ptr + numel, dout_ptr,
dx_ptr, PReluGradFunctor<T>(alpha_ptr));
// TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready
// TODO(Zhuoyuan): add dalpha upgrade when GPU kernels ready
}
};
......
......@@ -29,8 +29,11 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
void RecurrentAlgorithm::InferShape(const Scope& scope) const {
seq_len_ =
scope.FindVar(arg_->inlinks[0])->GetMutable<LoDTensor>()->dims()[0];
auto* input0 = scope.FindVar(arg_->inlinks[0]);
PADDLE_ENFORCE_NOT_NULL(input0);
seq_len_ = input0->GetMutable<LoDTensor>()->dims()[0];
PADDLE_ENFORCE_GT(seq_len_, 0);
CreateScopes(scope);
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
......
......@@ -28,14 +28,15 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
const size_t seq_len, bool infer_shape_mode) {
PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
for (size_t i = 0; i < inlinks.size(); ++i) {
auto input_var = step_scopes[0]->FindVar(inlinks[i]);
PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
inlinks[i]);
// global inputs
auto input_var = step_scopes[0]->parent().FindVar(inlinks[i]);
PADDLE_ENFORCE_NOT_NULL(input_var, "input link [%s] is not in scope.",
inlinks[i]);
LoDTensor* input = input_var->GetMutable<LoDTensor>();
f::DDim dims = input->dims();
PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
"all the inlinks must have same length");
PADDLE_ENFORCE_EQ(static_cast<size_t>(dims[0]), seq_len,
"all the inlinks be the same length");
f::DDim step_dims = slice_ddim(dims, 1, dims.size());
for (size_t j = 0; j < seq_len; j++) {
Tensor* step_input =
......@@ -54,15 +55,14 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<std::string>& outlinks,
const size_t seq_len, bool infer_shape_mode) {
for (size_t i = 0; i < outlinks.size(); i++) {
auto output_var = step_scopes[0]->FindVar(outlinks[i]);
PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
outlinks[i]);
auto output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.",
outlinks[i]);
LoDTensor* output = output_var->GetMutable<LoDTensor>();
if (infer_shape_mode) {
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
outlinks[i].internal);
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]);
f::DDim step_dims =
step_scope_var->template GetMutable<LoDTensor>()->dims();
std::vector<int64_t> dims_vec = vectorize(step_dims);
......
......@@ -59,7 +59,6 @@ class PySimpleRNNTest(unittest.TestCase):
def test_forward(self):
output = self.rnn.forward()
print 'output', output
def create_tensor(scope, name, shape, np_data):
......@@ -103,7 +102,7 @@ class TestRecurrentOp(unittest.TestCase):
ctx = core.DeviceContext.create(core.CPUPlace())
self.rnnop.infer_shape(self.scope)
self.rnnop.run(self.scope, ctx)
return np.array(self.scope.find_var("h").get_tensor())
return np.array(self.scope.find_var("h@mem").get_tensor())
def create_global_variables(self):
# create inlink
......@@ -123,7 +122,7 @@ class TestRecurrentOp(unittest.TestCase):
create_tensor(self.scope, "h_boot", [self.batch_size, self.input_dim],
h_boot_np_data)
self.scope.new_var("step_scopes")
self.scope.new_var("h")
self.scope.new_var("h@mem")
def create_rnn_op(self):
# create RNNOp
......@@ -133,7 +132,7 @@ class TestRecurrentOp(unittest.TestCase):
boot_memories=["h_boot"],
step_net="stepnet",
# outputs
outlinks=["h"],
outlinks=["h@mem"],
step_scopes="step_scopes",
# attributes
pre_memories=["h@pre"],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册