Commit 3c29224e, authored by superjom

remove alias

Parent: b818e647
@@ -58,6 +58,8 @@ class Scope {
   /// nullptr if cannot find.
   Variable* FindVar(const std::string& name) const;
 
+  const Scope& parent() const { return *parent_; }
+
   /// Find the scope or an ancestor scope that contains the given variable.
   const Scope* FindScope(const Variable* var) const;
......
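Note on the hunk above: the new `parent()` accessor is what the later `rnn` hunks in this commit rely on — a step scope can reach variables that live in the enclosing (parent) scope instead of holding renamed aliases of them. Below is a minimal, self-contained sketch of that lookup pattern; the simplified `Scope`/`Variable` classes are illustrative stand-ins, not Paddle's real API.

// Editorial sketch (not part of the commit): a child scope resolving a
// variable held by its parent, mirroring step_scopes[0]->parent().FindVar(...).
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Variable {
  int value = 0;  // stand-in for a real tensor
};

class Scope {
 public:
  explicit Scope(const Scope* parent = nullptr) : parent_(parent) {}

  Variable* NewVar(const std::string& name) {
    auto& slot = vars_[name];
    if (!slot) slot = std::make_unique<Variable>();
    return slot.get();
  }

  // Returns nullptr when the name is not in *this* scope; no ancestor search,
  // which is why callers that want a global variable go through parent().
  Variable* FindVar(const std::string& name) const {
    auto it = vars_.find(name);
    return it == vars_.end() ? nullptr : it->second.get();
  }

  const Scope& parent() const { return *parent_; }

 private:
  const Scope* parent_;
  std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
};

int main() {
  Scope global;                 // holds the RNN's in/out link variables
  global.NewVar("x")->value = 42;

  Scope step(&global);          // one step scope of the RNN
  std::cout << (step.FindVar("x") == nullptr) << "\n";     // 1: not local
  std::cout << step.parent().FindVar("x")->value << "\n";  // 42: found in parent
}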
@@ -94,7 +94,7 @@ class PReluGradKernel : public framework::OpKernel {
     Transform(context.device_context(), out_ptr, out_ptr + numel, dout_ptr,
               dx_ptr, PReluGradFunctor<T>(alpha_ptr));
-    // TODO (Zhuoyuan): add dalpha upgrade when GPU kernels ready
+    // TODO(Zhuoyuan): add dalpha upgrade when GPU kernels ready
   }
 };
......
@@ -29,8 +29,11 @@ using Tensor = framework::Tensor;
 using LoDTensor = framework::LoDTensor;
 
 void RecurrentAlgorithm::InferShape(const Scope& scope) const {
-  seq_len_ =
-      scope.FindVar(arg_->inlinks[0])->GetMutable<LoDTensor>()->dims()[0];
+  auto* input0 = scope.FindVar(arg_->inlinks[0]);
+  PADDLE_ENFORCE_NOT_NULL(input0);
+  seq_len_ = input0->GetMutable<LoDTensor>()->dims()[0];
+  PADDLE_ENFORCE_GT(seq_len_, 0);
   CreateScopes(scope);
   auto step_scopes = GetStepScopes(scope);
   rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
......
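For readers unfamiliar with the `PADDLE_ENFORCE_*` macros introduced above: they fail with a message when a runtime condition does not hold, so `InferShape` now reports a missing first inlink or an empty sequence instead of dereferencing a null pointer. The sketch below only approximates that behaviour with plain C++ exceptions; it is not Paddle's actual macro implementation.

// Editorial sketch (not part of the commit): rough stand-ins for
// PADDLE_ENFORCE_NOT_NULL / PADDLE_ENFORCE_GT that throw on failure.
#include <iostream>
#include <sstream>
#include <stdexcept>

#define ENFORCE_NOT_NULL(ptr)                                \
  do {                                                       \
    if ((ptr) == nullptr)                                    \
      throw std::runtime_error(#ptr " should not be null");  \
  } while (0)

#define ENFORCE_GT(a, b)                                              \
  do {                                                                \
    if (!((a) > (b))) {                                               \
      std::ostringstream oss;                                         \
      oss << #a " (" << (a) << ") must be > " #b " (" << (b) << ")";  \
      throw std::runtime_error(oss.str());                            \
    }                                                                 \
  } while (0)

int main() {
  // Mirrors the checks added to InferShape: validate the pointer first,
  // then validate the sequence length taken from dims()[0].
  int seq_len = 3;
  ENFORCE_GT(seq_len, 0);  // passes

  int* input0 = nullptr;
  try {
    ENFORCE_NOT_NULL(input0);  // throws
  } catch (const std::exception& e) {
    std::cerr << e.what() << "\n";  // "input0 should not be null"
  }
}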
@@ -28,14 +28,15 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
                    const size_t seq_len, bool infer_shape_mode) {
   PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
   for (size_t i = 0; i < inlinks.size(); ++i) {
-    auto input_var = step_scopes[0]->FindVar(inlinks[i]);
-    PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
-                   inlinks[i]);
+    // global inputs
+    auto input_var = step_scopes[0]->parent().FindVar(inlinks[i]);
+    PADDLE_ENFORCE_NOT_NULL(input_var, "input link [%s] is not in scope.",
+                            inlinks[i]);
     LoDTensor* input = input_var->GetMutable<LoDTensor>();
     f::DDim dims = input->dims();
-    PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
-                   "all the inlinks must have same length");
+    PADDLE_ENFORCE_EQ(static_cast<size_t>(dims[0]), seq_len,
+                      "all the inlinks be the same length");
     f::DDim step_dims = slice_ddim(dims, 1, dims.size());
     for (size_t j = 0; j < seq_len; j++) {
       Tensor* step_input =
@@ -54,15 +55,14 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<std::string>& outlinks,
                    const size_t seq_len, bool infer_shape_mode) {
   for (size_t i = 0; i < outlinks.size(); i++) {
-    auto output_var = step_scopes[0]->FindVar(outlinks[i]);
-    PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
-                   outlinks[i]);
+    auto output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
+    PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.",
+                            outlinks[i]);
     LoDTensor* output = output_var->GetMutable<LoDTensor>();
     if (infer_shape_mode) {
-      auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
-      PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
-                     outlinks[i].internal);
+      auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
+      PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]);
       f::DDim step_dims =
           step_scope_var->template GetMutable<LoDTensor>()->dims();
       std::vector<int64_t> dims_vec = vectorize(step_dims);
......
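Two things change in the SegmentInputs/ConcatOutputs hunks above: the in/out link variables are now looked up in the parent scope of the step scopes, and every inlink of shape [seq_len, ...] must share the same leading dimension, which is cut into seq_len per-step slices (slice_ddim drops the leading dimension). Below is a rough sketch of that per-step slicing on a flat buffer, purely illustrative and not the LoDTensor code.

// Editorial sketch (not part of the commit): slice a [seq_len, step_size]
// buffer into seq_len per-step views, as SegmentInputs does for each inlink.
#include <cassert>
#include <iostream>
#include <vector>

int main() {
  const size_t seq_len = 3;
  const size_t step_size = 2;  // product of trailing dims, cf. slice_ddim(dims, 1, dims.size())
  std::vector<float> inlink = {0, 1, 10, 11, 20, 21};  // logical shape [3, 2]
  assert(inlink.size() == seq_len * step_size);

  // Step j gets the slice inlink[j * step_size .. (j + 1) * step_size).
  std::vector<const float*> step_inputs(seq_len);
  for (size_t j = 0; j < seq_len; ++j) {
    step_inputs[j] = inlink.data() + j * step_size;
  }

  std::cout << step_inputs[1][0] << " " << step_inputs[1][1] << "\n";  // 10 11
}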
@@ -59,7 +59,6 @@ class PySimpleRNNTest(unittest.TestCase):
     def test_forward(self):
         output = self.rnn.forward()
-        print 'output', output
 
 
 def create_tensor(scope, name, shape, np_data):
@@ -103,7 +102,7 @@ class TestRecurrentOp(unittest.TestCase):
         ctx = core.DeviceContext.create(core.CPUPlace())
         self.rnnop.infer_shape(self.scope)
         self.rnnop.run(self.scope, ctx)
-        return np.array(self.scope.find_var("h").get_tensor())
+        return np.array(self.scope.find_var("h@mem").get_tensor())
 
     def create_global_variables(self):
         # create inlink
@@ -123,7 +122,7 @@ class TestRecurrentOp(unittest.TestCase):
         create_tensor(self.scope, "h_boot", [self.batch_size, self.input_dim],
                       h_boot_np_data)
         self.scope.new_var("step_scopes")
-        self.scope.new_var("h")
+        self.scope.new_var("h@mem")
 
     def create_rnn_op(self):
         # create RNNOp
@@ -133,7 +132,7 @@ class TestRecurrentOp(unittest.TestCase):
             boot_memories=["h_boot"],
             step_net="stepnet",
             # outputs
-            outlinks=["h"],
+            outlinks=["h@mem"],
             step_scopes="step_scopes",
             # attributes
             pre_memories=["h@pre"],
......