Commit b97d61ad, authored by: Q qijun

merge baidu/develop

@@ -8,23 +8,24 @@
 - cpu MHz : 2101.000
 - cache size : 20480 KB
 
+### Blas settings
+Setting environment variable: `MKL_NUM_THREADS=1`.
 
 ### Single Node Single Thread
-- PServer Count: 10
-- Trainer Count: 20
 - Metrics: samples / sec
 
 | Batch Size | 32 | 64 | 128 | 256 |
 | -- | -- | -- | -- | -- |
 | PaddlePaddle Fluid | 15.44 | 16.32 | 16.74 | 16.79 |
 | PaddlePaddle v2 | 15.97 | 17.04 | 17.60 | 17.83 |
-| TensorFlow | - | - | - | - |
+| TensorFlow | 9.09 | 9.10 | 9.24 | 8.66 |
 
 ### Different Batch Size
 - PServer Count: 10
 - Trainer Count: 20
+- Per trainer CPU Core: 1
 - Metrics: samples / sec
 
 | Batch Size | 32 | 64 | 128 | 256 |
......
@@ -16,6 +16,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 
+#include <queue>
+
 namespace paddle {
 namespace framework {
 
@@ -64,12 +66,36 @@ VarDesc *BlockDesc::RenameVar(const std::string &old_name,
 VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
   if (name == kEmptyVarName) return nullptr;
 
-  auto it = vars_.find(name);
-  if (it == vars_.end()) {
-    return Parent() == kNoneBlockIndex ? nullptr
-                                       : ParentBlock()->FindVarRecursive(name);
+  std::queue<const BlockDesc *> frontier;
+  std::unordered_set<const BlockDesc *> visited;
+
+  frontier.push(this);
+
+  while (!frontier.empty()) {  // BFS
+    auto cur = frontier.front();
+    frontier.pop();
+    if (visited.count(cur) != 0) {
+      continue;
+    }
+    auto var = cur->FindVar(name);
+    if (var != nullptr) {
+      return var;
+    }
+
+    auto fwd = cur->ForwardBlock();
+    auto parent = cur->ParentBlock();
+
+    if (fwd != nullptr) {
+      frontier.push(fwd);
+    }
+    if (parent != nullptr) {
+      frontier.push(parent);
+    }
+
+    visited.insert(cur);
   }
-  return it->second.get();
+
+  return nullptr;
 }
 
 VarDesc &BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
@@ -155,10 +181,7 @@ void BlockDesc::Flush() {
 }
 
 BlockDesc *BlockDesc::ParentBlock() const {
-  if (this->desc_->parent_idx() == kNoneBlockIndex) {
-    return nullptr;
-  }
-  return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
+  return prog_->MutableBlock(static_cast<size_t>(desc_->parent_idx()));
 }
 
 proto::BlockDesc *BlockDesc::Proto() {
@@ -205,5 +228,16 @@ void BlockDesc::ClearPBVars() {
   }
 }
 
+void BlockDesc::SetForwardBlockID(int32_t forward_block_id) {
+  PADDLE_ENFORCE(!desc_->has_forward_block_idx(),
+                 "Forward block ID has been set to %d. Cannot set it to %d",
+                 desc_->forward_block_idx(), forward_block_id);
+  desc_->set_forward_block_idx(forward_block_id);
+}
+
+BlockDesc *BlockDesc::ForwardBlock() const {
+  return prog_->MutableBlock(static_cast<size_t>(desc_->forward_block_idx()));
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -49,6 +49,8 @@ class BlockDesc {
 
   int32_t Parent() const { return desc_->parent_idx(); }
 
+  int32_t ForwardBlockID() const { return desc_->forward_block_idx(); }
+
   VarDesc *Var(const std::string &name_bytes);
 
   VarDesc *FindVar(const std::string &name_bytes) const;
 
@@ -75,6 +77,10 @@ class BlockDesc {
 
   BlockDesc *ParentBlock() const;
 
+  BlockDesc *ForwardBlock() const;
+
+  void SetForwardBlockID(int32_t forward_block_id);
+
   OpDesc *AppendOp();
 
   void AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
 
@@ -93,7 +99,7 @@ class BlockDesc {
 
   proto::BlockDesc *Proto();
 
-  ProgramDesc *Program() { return this->prog_; }
+  ProgramDesc *Program() const { return this->prog_; }
 
  private:
  void ClearPBOps();
......
@@ -100,8 +100,7 @@ class ChannelHolder {
     virtual ~Placeholder() {}
     virtual const std::type_index Type() const = 0;
     virtual void* Ptr() const = 0;
-    virtual void Close() const = 0;
-    std::type_info type_;
+    virtual void Close() = 0;
   };
 
   template <typename T>
 
@@ -116,7 +115,7 @@ class ChannelHolder {
       if (channel_) channel_->Close();
     }
 
-    std::unique_ptr<Channel<T>*> channel_;
+    std::unique_ptr<Channel<T>> channel_;
     const std::type_index type_;
   };
......
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "gtest/gtest.h"
 
 using paddle::framework::Channel;
+using paddle::framework::ChannelHolder;
 using paddle::framework::MakeChannel;
 using paddle::framework::CloseChannel;
 using paddle::framework::details::Buffered;
 
@@ -508,3 +509,36 @@ TEST(Channel, UnbufferedChannelDestroyUnblocksSendersTest) {
   auto ch = MakeChannel<int>(0);
   ChannelDestroyUnblockSenders(ch);
 }
+
+void ChannelHolderSendReceive(ChannelHolder *ch) {
+  unsigned sum_send = 0;
+  std::thread t([&]() {
+    for (int i = 0; i < 5; i++) {
+      EXPECT_EQ(ch->Send(&i), true);
+      sum_send += i;
+    }
+  });
+  for (int i = 0; i < 5; i++) {
+    int recv;
+    EXPECT_EQ(ch->Receive(&recv), true);
+    EXPECT_EQ(recv, i);
+  }
+
+  ch->close();
+  t.join();
+  EXPECT_EQ(sum_send, 10U);
+}
+
+TEST(ChannelHolder, ChannelHolderBufferedSendReceiveTest) {
+  ChannelHolder *ch = new ChannelHolder();
+  ch->Reset<int>(10);
+  ChannelHolderSendReceive(ch);
+  delete ch;
+}
+
+TEST(ChannelHolder, ChannelHolderUnBufferedSendReceiveTest) {
+  ChannelHolder *ch = new ChannelHolder();
+  ch->Reset<int>(0);
+  ChannelHolderSendReceive(ch);
+  delete ch;
+}
@@ -158,6 +158,7 @@ message BlockDesc {
   required int32 parent_idx = 2;
   repeated VarDesc vars = 3;
   repeated OpDesc ops = 4;
+  optional int32 forward_block_idx = 5 [ default = -1 ];
 }
 
 // Please refer to
......
@@ -38,7 +38,13 @@ class ProgramDesc {
 
   BlockDesc *AppendBlock(const BlockDesc &parent);
 
-  BlockDesc *MutableBlock(size_t idx) { return blocks_[idx].get(); }
+  BlockDesc *MutableBlock(size_t idx) {
+    if (idx == static_cast<size_t>(kNoneBlockIndex)) {
+      return nullptr;
+    } else {
+      return blocks_[idx].get();
+    }
+  }
 
   const BlockDesc &Block(size_t idx) const { return *blocks_[idx]; }
......
@@ -231,7 +231,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     while_grad->SetInput(kStepScopes, Output(kStepScopes));
 
     auto *grad_block = this->grad_block_[0];
-    auto *fwd_block = grad_block->ParentBlock();
+    auto *fwd_block = grad_block->ForwardBlock();
+    auto *parent_block = grad_block->ParentBlock();
 
     // Not all of IGs will be generated by inner gradient operators of while op.
     // Ignore IGs that is not generated by the inside block.
 
@@ -260,33 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     for (auto &o : Output(kOutputs)) {
       block_ins.insert(o);
     }
-    std::unordered_set<std::string> extra_inputs;
+    std::unordered_set<std::string> output_grads;
     for (const auto *op : grad_block->AllOps()) {
      for (auto &input_name : op->InputArgumentNames()) {
        // If the input of Op has been recorded or is generated by the forward
        // block, do not make it as input again.
+
+        // The input is located in I/O or other op's outputs or the variable is
+        // located in grad_block's parents
        if (block_ins.find(input_name) != block_ins.end() ||
-            fwd_block->FindVar(input_name) != nullptr) {
+            (fwd_block->FindVarRecursive(input_name) != nullptr ||
+             parent_block->FindVarRecursive(input_name) != nullptr)) {
          continue;
        }
-        extra_inputs.insert(input_name);
+        output_grads.insert(input_name);
      }
      for (auto &output_name : op->OutputArgumentNames()) {
        block_ins.insert(output_name);
      }
    }
 
-    std::vector<std::string> extra_inputs_list;
-    extra_inputs_list.resize(extra_inputs.size());
-    std::copy(extra_inputs.begin(), extra_inputs.end(),
-              extra_inputs_list.begin());
-    while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
+    std::vector<std::string> output_grads_list;
+    output_grads_list.resize(output_grads.size());
+    std::copy(output_grads.begin(), output_grads.end(),
+              output_grads_list.begin());
+    while_grad->SetInput(framework::GradVarName(kOutputs), output_grads_list);
 
     while_grad->SetAttrMap(this->Attrs());
    while_grad->SetBlockAttr(kStepBlock, *grad_block);
    // record the original output gradient names, since the gradient name of
    // while operator could be renamed.
-    while_grad->SetAttr("original_output_grad", extra_inputs_list);
+    while_grad->SetAttr("original_output_grad", output_grads_list);
 
     return std::unique_ptr<framework::OpDesc>(while_grad);
   }
......
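The filtering in `WhileGradOpDescMaker` above can be summarized by the following toy function (hypothetical helper names, not the operator code): an input of the gradient block becomes an output-gradient input of `while_grad` only if it is neither already known inside the block nor resolvable in the forward block or the parent block.

```python
def collect_output_grads(grad_ops, block_ins, resolvable):
    """Toy stand-in for the loop above: returns the unresolved input names."""
    output_grads = set()
    known = set(block_ins)
    for op_inputs, op_outputs in grad_ops:
        for name in op_inputs:
            if name in known or resolvable(name):
                continue
            output_grads.add(name)
        known.update(op_outputs)   # outputs of earlier grad ops count as known
    return sorted(output_grads)

# grad ops given as (inputs, outputs): "h" lives in the forward block, "tmp" is
# produced by an earlier grad op, and the *@GRAD names come from outside the block.
ops = [(["h", "x@GRAD"], ["tmp"]), (["tmp", "y@GRAD"], ["x@GRAD_renamed"])]
print(collect_output_grads(ops, block_ins={"x", "y"},
                           resolvable=lambda name: name == "h"))
# -> ['x@GRAD', 'y@GRAD']
```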
@@ -155,6 +155,8 @@ void BindBlockDesc(py::module &m) {
   py::class_<BlockDesc>(m, "BlockDesc", "")
       .def_property_readonly("id", &BlockDesc::ID)
       .def_property_readonly("parent", &BlockDesc::Parent)
+      .def("get_forward_block_idx", &BlockDesc::ForwardBlockID)
+      .def("set_forward_block_idx", &BlockDesc::SetForwardBlockID)
       .def("append_op", &BlockDesc::AppendOp,
            py::return_value_policy::reference)
       .def("prepend_op", &BlockDesc::PrependOp,
......
@@ -298,7 +298,8 @@ def _append_backward_ops_(block,
         # If the op has its own sub-block, deal with the sub-block first
         if op.has_attr("sub_block"):
             sub_block = program.block(op.block_attr("sub_block"))
-            grad_sub_block = program.create_block(parent_idx=sub_block.idx)
+            grad_sub_block = program.create_block()
+            grad_sub_block.set_forward_block_idx(sub_block.idx)
             cb = _callback_lookup_(op)
             if cb is not None:
                 if callbacks is None:
 
@@ -310,6 +311,8 @@
             else:
                 _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
                                       no_grad_dict, grad_to_var, callbacks)
+
+            program.rollback()
             grad_sub_block_list.append(grad_sub_block.desc)
 
         # Getting op's corresponding grad_op
......
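As a rough sketch of what the three calls above (`create_block`, `set_forward_block_idx`, `program.rollback`) do to the block tree, the following toy model (ToyProgram is invented, not the fluid API) shows that the gradient sub-block ends up as a sibling of the forward body rather than its child, tied to it only through `forward_block_idx`.

```python
# Stand-in semantics: create_block() opens a new block whose parent is the
# current block, set_forward_block_idx() records the forward twin, and
# rollback() makes the previous block current again.

class ToyProgram(object):
    def __init__(self):
        self.blocks = [{"idx": 0, "parent": -1, "forward": -1}]
        self.current = 0

    def create_block(self):
        blk = {"idx": len(self.blocks), "parent": self.current, "forward": -1}
        self.blocks.append(blk)
        self.current = blk["idx"]
        return blk

    def rollback(self):
        self.current = self.blocks[self.current]["parent"]

prog = ToyProgram()
sub_block = prog.create_block()        # forward sub-block, e.g. a while body (idx 1)
prog.rollback()                        # back to the global block
grad_sub_block = prog.create_block()   # grad sub-block (idx 2), parent is block 0
grad_sub_block["forward"] = sub_block["idx"]
prog.rollback()

# The grad block is not a child of the forward body; only forward_block_idx
# ties the two together, which is what FindVarRecursive now follows.
assert (grad_sub_block["parent"], grad_sub_block["forward"]) == (0, 1)
```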
@@ -152,7 +152,7 @@ class Variable(object):
         shape(tuple|list|None): The shape of variable. -1 means the batch size.
             Some kinds of variable do not contain shape, just set it to None.
         dtype(np.dtype|core.VarDesc.VarType|str): The data type of variable.
-        lod_level(int): The level of lod tensor. 0 means there is not a time
+        lod_level(int): The level of lod tensor. 0 means it is not a time
             series data.
         persistable(bool): True if the variable should be saved as check point.
             Defaults to False.
 
@@ -346,7 +346,7 @@ class OpProtoHolder(object):
     def __init__(self):
         assert not hasattr(
             self.__class__,
-            '_instance'), 'Please use `instance()` to get OpProtoHolder opject!'
+            '_instance'), 'Please use `instance()` to get OpProtoHolder object!'
         op_protos = get_all_op_protos()
         self.op_proto_map = {}
         for proto in op_protos:
 
@@ -368,8 +368,8 @@
 
 class Operator(object):
     """
-    Python Operator class. The operator represents the build in instructs in a
-    Block. Users can use the build in instructs to describe their neural
+    Python Operator class. The operator represents the build in instructions in a
+    Block. Users can use the build in instructions to describe their neural
     network.
     """
 
@@ -478,7 +478,7 @@
                 raise TypeError("'attrs' should be a dict.")
             for attr in proto.attrs:
                 attr_name = attr.name
-                if (not attr_name in attrs) or (attrs[attr_name] is None):
+                if (attr_name not in attrs) or (attrs[attr_name] is None):
                     continue
                 if isinstance(attrs[attr_name], Block):
                     self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
 
@@ -696,6 +696,13 @@ class Block(object):
     def parent_idx(self):
         return self.desc.parent
 
+    @property
+    def forward_block_idx(self):
+        return self.desc.get_forward_block_idx()
+
+    def set_forward_block_idx(self, idx):
+        self.desc.set_forward_block_idx(idx)
+
     @property
     def idx(self):
         return self.desc.id
 
@@ -709,15 +716,32 @@
         return v
 
     def var_recursive(self, name):
-        if self.has_var(name):
-            return self.var(name)
-        else:
-            if self.idx == 0:
-                raise ValueError("var %s is not in block(%d) nor its parents." %
-                                 name, self.idx)
-            else:
-                parent_block = self.program.block(self.parent_idx)
-                return parent_block.var_recursive(name)
+        frontier = list()
+        visited = set()
+
+        frontier.append(self)
+
+        prog = self.program
+
+        while len(frontier) != 0:  # BFS
+            cur = frontier[0]
+            frontier = frontier[1:]
+
+            if id(cur) in visited:
+                continue
+
+            if cur.has_var(name):
+                return cur.var(name)
+
+            if cur.parent_idx != -1:
+                frontier.append(prog.block(cur.parent_idx))
+
+            if cur.forward_block_idx != -1:
+                frontier.append(prog.block(cur.forward_block_idx))
+
+            visited.add(id(cur))
+
+        raise ValueError("Var {0} is not found recursively".format(name))
 
     def all_parameters(self):
         return list(self.iter_parameters())
 
@@ -727,7 +751,7 @@
                     if isinstance(item[1], Parameter))
 
     def create_var(self, *args, **kwargs):
-        var = Variable(self, *args, **kwargs)
+        var = Variable(block=self, *args, **kwargs)
         if 'initializer' in kwargs:
             kwargs['initializer'](var, self)
         return var
 
@@ -798,13 +822,13 @@
 
     def append_op(self, *args, **kwargs):
         op_desc = self.desc.append_op()
-        op = Operator(self, op_desc, *args, **kwargs)
+        op = Operator(block=self, desc=op_desc, *args, **kwargs)
         self.ops.append(op)
         return op
 
     def delete_ops(self, ops):
         # remove from cpp
-        # FIXME(typhoonzero): remove only the first occuracy.
+        # FIXME(typhoonzero): remove only the first occurrence.
         try:
             start = list(self.ops).index(ops[0])
             end = list(self.ops).index(ops[-1])
 
@@ -822,6 +846,11 @@
         return op
 
     def sync_with_cpp(self):
+        """
+        Sync with the desc on the c++ end.
+
+        This method is used to synchronize the c++ desc instance generated by backward.
+        """
         # sync variables from cpp
         for var in self.desc.all_vars():
             if not self.has_var(var.name()):
 
@@ -867,9 +896,9 @@
     def copy_param_info_from(self, other):
         """
-        Copy the information of parameters from other block
+        Copy the information of parameters from the other block
 
         Args:
-            other(Block): other block
+            other(Block): the other block
 
         Returns:
             None
 
@@ -1215,6 +1244,6 @@ def get_var(name, program=None):
     if program is None:
         program = default_main_program()
     assert isinstance(name, str)
-    assert isinstance(name, Program)
+    assert isinstance(program, Program)
 
     return program.global_block().var(name)
@@ -104,7 +104,7 @@ def fc(input,
     * :math:`X_i`: The input tensor.
     * :math:`W`: The weights created by this layer.
     * :math:`b`: The bias parameter created by this layer (if needed).
-    * :math:`Act`: The activation funtion.
+    * :math:`Act`: The activation function.
     * :math:`Out`: The output tensor.
 
     Args:
......
@@ -220,15 +220,15 @@ def _process_sub_block_pair(pdesc, sub_block_pair):
         # Find fwd_op/bwd_op block pair
         for grad_id in grad_sub_block_ids:
-            parent_id = pdesc.block(grad_id).parent
-            if parent_id in sub_block_ids:
-                sub_block_id_pair.append((parent_id, grad_id))
-                sub_block_ids.remove(parent_id)
+            fwd_id = pdesc.block(grad_id).get_forward_block_idx()
+            if fwd_id in sub_block_ids:
+                sub_block_id_pair.append((fwd_id, grad_id))
+                sub_block_ids.remove(fwd_id)
 
         # Get fwd_op/bwd_op block ops
-        for parent_id, grad_id in sub_block_id_pair:
+        for fwd_id, grad_id in sub_block_id_pair:
             sub_block_ops = []
-            sub_block = pdesc.block(parent_id)
+            sub_block = pdesc.block(fwd_id)
             block_op_size = sub_block.op_size()
             for i in range(block_op_size):
                 sub_block_ops.append(sub_block.op(i))
 
@@ -239,19 +239,19 @@ def _process_sub_block_pair(pdesc, sub_block_pair):
                 sub_block_ops.append(grad_sub_block.op(i))
 
             sub_op_output = set()
-            sub_op_output.update(sub_op_dict[parent_id].output_arg_names())
+            sub_op_output.update(sub_op_dict[fwd_id].output_arg_names())
             sub_op_output.update(sub_op_dict[grad_id].output_arg_names())
             ops_list.append((sub_block_ops, block_op_size, sub_op_output))
 
         # Process rest fwd_op block ops
-        for parent_id in sub_block_ids:
+        for fwd_id in sub_block_ids:
             sub_block_ops = []
-            sub_block = pdesc.block(parent_id)
+            sub_block = pdesc.block(fwd_id)
             sub_block_op_size = sub_block.op_size()
             for i in range(sub_block_op_size):
                 sub_block_ops.append(sub_block.op(i))
 
             sub_op_output = set()
-            sub_op_output.update(sub_op_dict[parent_id].output_arg_names())
+            sub_op_output.update(sub_op_dict[fwd_id].output_arg_names())
             ops_list.append((sub_block_ops, sub_block_op_size, sub_op_output))
 
     return ops_list
......
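A toy version of the pairing logic above (plain Python, not the transpiler) may help: each gradient sub-block is matched with the forward sub-block it references through `get_forward_block_idx()`, and forward sub-blocks without a gradient twin are handled separately afterwards.

```python
def pair_sub_blocks(sub_block_ids, grad_ids, forward_idx_of):
    """Pair grad blocks with their forward blocks; return leftovers separately."""
    remaining = set(sub_block_ids)
    pairs = []
    for grad_id in grad_ids:
        fwd_id = forward_idx_of[grad_id]
        if fwd_id in remaining:
            pairs.append((fwd_id, grad_id))
            remaining.remove(fwd_id)
    return pairs, sorted(remaining)

pairs, rest = pair_sub_blocks(sub_block_ids=[1, 3], grad_ids=[2],
                              forward_idx_of={2: 1})
assert pairs == [(1, 2)] and rest == [3]   # block 3 has no grad sub-block
```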