提交 82a22d32 编写于 作者: Y Yang Yu

Update code

上级 978d1288
......@@ -66,14 +66,6 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), pdesc.Size());
auto& block = pdesc.Block(block_id);
if (VLOG_IS_ON(100)) {
std::ostringstream sout;
for (auto& name : scope->GetAllNames(false)) {
sout << name << ", ";
}
VLOG(100) << "Scope has variable " << sout.str();
}
Scope* local_scope = scope;
if (create_vars) {
if (create_local_scope) {
......
......@@ -134,6 +134,14 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
#endif
offset_ = 0;
}
if (typeid(float).hash_code() == type.hash_code()) {
auto buf = reinterpret_cast<float*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
for (int64_t i = 0; i < this->numel(); ++i) {
buf[i] = NAN;
}
}
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
......
......@@ -107,10 +107,12 @@ class SumKernel : public framework::OpKernel<T> {
out_array.resize(i + 1);
}
if (out_array[i].numel() == 0) {
VLOG(10) << context.op().Output("Out") << " just copy";
framework::CopyFrom(in_array[i], in_array[i].place(),
context.device_context(), &out_array[i]);
out_array[i].set_lod(in_array[i].lod());
} else {
VLOG(10) << context.op().Output("Out") << " merged";
PADDLE_ENFORCE(out_array[i].lod() == in_array[i].lod());
auto in = EigenVector<T>::Flatten(in_array[i]);
auto result = EigenVector<T>::Flatten(out_array[i]);
......
import numpy as np
import contextlib
from framework import Program, default_main_program
from . import core
from framework import Program, default_main_program, Parameter, Variable
__all__ = ['Executor', 'g_scope']
__all__ = ['Executor', 'global_scope', 'scope_guard', 'switch_scope']
g_scope = core.Scope()
def global_scope():
return g_scope
def switch_scope(scope):
global g_scope
ex = g_scope
g_scope = scope
return ex
@contextlib.contextmanager
def scope_guard(scope):
ex = switch_scope(scope)
yield
switch_scope(ex)
def as_numpy(tensor):
if isinstance(tensor, list):
return [as_numpy(t) for t in tensor]
......@@ -117,7 +136,7 @@ class Executor(object):
raise TypeError()
if scope is None:
scope = g_scope
scope = global_scope()
program = program.clone()
global_block = program.global_block()
......
......@@ -170,7 +170,7 @@ def main():
exe.run(fluid.default_startup_program())
embedding_param = fluid.g_scope.find_var(embedding_name).get_tensor()
embedding_param = fluid.global_scope().find_var(embedding_name).get_tensor()
embedding_param.set(
load_parameter(conll05.get_embedding(), word_dict_len, word_dim), place)
......
......@@ -19,8 +19,10 @@ def prog_scope():
def __fn__(*args, **kwargs):
prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(prog, startup_prog):
fn(*args, **kwargs)
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
fn(*args, **kwargs)
return __fn__
......
......@@ -298,7 +298,6 @@ class TestSimpleMulWithMemory(unittest.TestCase):
@prog_scope()
def test_forward_backward(self):
py_rnn = TestSimpleMulWithMemory.SimpleMulWithMemory()
data = fluid.layers.data(
name=self.DATA_NAME, shape=[self.DATA_WIDTH], lod_level=1)
data.stop_gradient = False
......@@ -323,19 +322,18 @@ class TestSimpleMulWithMemory(unittest.TestCase):
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
feed = py_rnn.to_feed(cpu)
for _ in xrange(2):
last_np, w_g, i_g = map(numpy.array,
exe.run(feed=feed,
fetch_list=[
last, self.PARAM_NAME + "@GRAD",
self.DATA_NAME + "@GRAD"
],
return_numpy=False))
last_np, w_g, i_g = map(numpy.array,
exe.run(feed=feed,
fetch_list=[
last, self.PARAM_NAME + "@GRAD",
self.DATA_NAME + "@GRAD"
],
return_numpy=False))
last_by_py, = py_rnn.exe().values()
self.assertTrue(numpy.allclose(last_np, last_by_py))
w_g_num = py_rnn.get_numeric_gradient_of_param(self.PARAM_NAME)
print w_g[0], w_g_num[0]
# print w_g_num[0], w_g[0]
self.assertTrue(numpy.allclose(w_g_num, w_g, rtol=0.1))
i_g_num = py_rnn.get_numeric_gradient_of_input(self.DATA_NAME)
i_g_num = i_g_num.reshape(i_g.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册