提交 e412b1ae 编写于 作者: J JiayiFeng

Merge branch 'unify_executor_interface' into add_parallel_executor_tests

......@@ -57,7 +57,7 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8
# specify sphinx version as 1.5.6 and remove -U option for [pip install -U
# sphinx-rtd-theme] since -U option will cause sphinx being updated to newest
# version(1.7.1 for now), which causes building documentation failed.
RUN pip install --upgrade pip && \
RUN pip install --upgrade pip==9.0.3 && \
pip install -U wheel && \
pip install -U docopt PyYAML sphinx==1.5.6 && \
pip install sphinx-rtd-theme==0.1.9 recommonmark
......
......@@ -4,6 +4,7 @@
.. toctree::
:maxdepth: 1
api_doc_std_cn.md
new_op_cn.md
new_op_kernel.md
use_eigen_cn.md
......
......@@ -4,6 +4,7 @@ Development
.. toctree::
:maxdepth: 1
api_doc_std_en.md
new_op_en.md
new_op_kernel.md
use_eigen_en.md
......
......@@ -55,21 +55,21 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
}
}
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
const OpDesc &op,
const platform::Place &p,
const size_t &i) const {
auto *op_handle = result->ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
op_handle->dev_ctxes_[p] = platform::DeviceContextPool::Instance().Get(p);
auto var_names = op->InputArgumentNames();
auto var_names = op.InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
var_names = op.OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(result, op_handle, each_var_name, p, i);
......@@ -107,7 +107,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
// Create inputs for output on original place and no ssa output
// is created for send op.
CreateOpHandleIOs(&result, op, p, 0);
CreateOpHandleIOs(&result, *op, p, 0);
continue;
}
......@@ -117,7 +117,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
auto *op_handle = result.ops_.back().get();
CreateOpHandleIOs(&result, op, p, i);
CreateOpHandleIOs(&result, *op, p, i);
auto var_names = op->OutputArgumentNames();
......
......@@ -45,8 +45,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
private:
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
const size_t &i) const;
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
const platform::Place &p, const size_t &i) const;
private:
std::string loss_var_name_;
......
......@@ -169,7 +169,7 @@ class Accuracy(MetricBase):
return self.value / self.weight
class ChunkEvalutor(MetricBase):
class ChunkEvaluator(MetricBase):
"""
Accumulate counter numbers output by chunk_eval from mini-batches and
compute the precision recall and F1-score using the accumulated counter
......@@ -177,7 +177,7 @@ class ChunkEvalutor(MetricBase):
"""
def __init__(self, name=None):
super(ChunkEvalutor, self).__init__(name)
super(ChunkEvaluator, self).__init__(name)
self.num_infer_chunks = 0
self.num_label_chunks = 0
self.num_correct_chunks = 0
......
......@@ -61,8 +61,8 @@ class ParallelExecutor(object):
main_program=test_program,
share_vars_from=train_exe)
train_loss, = train_exe.run([loss.name], feed_dict=feed_dict)
test_loss, = test_exe.run([loss.name], feed_dict=feed_dict)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
test_loss, = test_exe.run([loss.name], feed=feed_dict)
"""
self._places = []
......@@ -123,22 +123,23 @@ class ParallelExecutor(object):
allow_op_delay)
self.scope = scope
def run(self, fetch_list, feed_dict={}):
def run(self, fetch_list, feed={}, feed_dict={}):
"""
:param fetch_list: A list of variable names that will be fetched.
:param feed_dict: A dict mapping for feed variable name to LoDTensor
:param feed: A dict mapping for feed variable name to LoDTensor
or numpy array.
:return: fetched value list.
"""
if not isinstance(feed_dict, dict):
raise TypeError("feed_dict should be a dict")
feed = feed_dict
if not isinstance(feed, dict):
raise TypeError("feed should be a dict")
feed_tensor_dict = {}
for i, feed_name in enumerate(feed_dict):
feed_tensor = feed_dict[feed_name]
for i, feed_name in enumerate(feed):
feed_tensor = feed[feed_name]
if not isinstance(feed_tensor, core.LoDTensor):
feed_tensor = core.LoDTensor()
feed_tensor.set(feed_dict[feed_name], self._act_places[0])
feed_tensor.set(feed[feed_name], self._act_places[0])
feed_tensor_dict[feed_name] = feed_tensor
fetch_var_name = '@FETCHED_VAR_NAME@'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册