提交 2f56d4b3 编写于 作者: Y Yang Yang

forward pass compile time

上级 b2ee9190
......@@ -23,9 +23,11 @@ namespace operators {
constexpr char kInputs[] = "inputs";
constexpr char kParameters[] = "parameters";
constexpr char kPlaces[] = "places";
constexpr char kParallelBlock[] = "sub_block";
constexpr char kOutputs[] = "outputs";
constexpr char kParallelScopes[] = "sub_scopes";
constexpr char kParallelScopes[] = "parallel_scopes";
constexpr char kParallelBlock[] = "sub_block";
// #define GRAD_SUFFIX "@GRAD"
// constexpr char kInputGrads[] = "inputs" GRAD_SUFFIX;
// constexpr char kOutputGrads[] = "outputs" GRAD_SUFFIX;
......
......@@ -424,7 +424,8 @@ class Operator(object):
self.desc.check_attrs()
no_kernel_op_set = {
'feed', 'fetch', 'save', 'load', 'recurrent',
'rnn_memory_helper_grad', 'conditional_block', 'while'
'rnn_memory_helper_grad', 'conditional_block', 'while',
'parallel_do'
}
if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc)
......
......@@ -103,6 +103,7 @@ class ParallelDo(object):
def read_input(self, var):
self.inputs.append(var)
return var
def write_output(self, var):
self.outputs.append(var)
......@@ -149,7 +150,7 @@ class ParallelDo(object):
'places': self.places
},
outputs={'outputs': self.outputs,
'step_scopes': [step_scope]},
'parallel_scopes': [step_scope]},
attrs={'sub_block': current_block})
......
import unittest
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid as fluid
from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.backward import append_backward_ops
import numpy as np
import paddle.v2.fluid.core as core
class ParallelOpTest(unittest.TestCase):
def setUp(self):
x = layers.data(
shape=[2, 3, 4], dtype='float32', name='x', append_batch_size=False)
places = fluid.default_main_program().global_block().create_var()
pd = layers.ParallelDo(places=places)
with pd.do():
data = pd.read_input(x)
hidden = layers.fc(input=data, size=7)
pd.write_output(hidden)
data = pd()
print data
print fluid.default_main_program()
def test_forward(self):
pass
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册