Commit 44c52a8c authored by yuyang18

Polish op_proto_maker

Parent 577c19b2
paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -163,8 +163,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       if (!is_forwarding && places_.size() > 1) {
         // Currently, we assume that once gradient is generated, it can be
         // broadcast, and each gradient is only broadcast once.
-        for (auto &og : op->OutputArgumentNames()) {
-          if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
+        if (static_cast<bool>(boost::get<int>(op->GetAttr(
+                                  OpProtoAndCheckerMaker::OpRoleAttrName())) &
+                              static_cast<int>(OpRole::kBackward))) {
+          auto &backward_vars = boost::get<std::vector<std::string>>(
+              op->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+
+          for (auto &og : backward_vars) {
             if (balance_parameter_opt_between_cards_) {
               CreateReduceOp(&result, og, cur_device_id);
               var_name_on_devices[cur_device_id].emplace(og);
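The graph builder now decides which outputs are parameter gradients by reading the op's role attribute and role-var list rather than scanning every output name. A minimal plain-Python sketch of the bitmask test, assuming the flag values from the OpRole enum in op_proto_maker.h (kForward = 0x0000, kBackward = 0x0001, kLoss = 0x0100):

# Sketch only: flag values assumed to mirror the OpRole enum.
K_FORWARD, K_BACKWARD, K_LOSS = 0x0000, 0x0001, 0x0100

def is_backward_op(op_role):
    # Equivalent to: GetAttr(OpRoleAttrName()) & OpRole::kBackward
    return bool(op_role & K_BACKWARD)

assert is_backward_op(K_BACKWARD)
assert is_backward_op(K_BACKWARD | K_LOSS)   # the scale-loss op matches too
assert not is_backward_op(K_FORWARD)

Because roles are bit flags, an op tagged with a combined role (such as loss plus backward) still satisfies a single-flag mask.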
@@ -399,11 +404,11 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
-  // FIXME(yy): Do not hard code like this
-  return op.OutputArgumentNames().size() == 1 &&
-         op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);
+  return boost::get<int>(
+             op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+         (static_cast<int>(OpRole::kBackward) |
+          static_cast<int>(OpRole::kLoss));
 }
 
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
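IsScaleLossOp drops the hard-coded output-name heuristic and instead tests for the exact combined role stamped on the scale-loss op. Note the exact equality rather than a mask, so ordinary backward ops do not match. A self-contained sketch under the same assumed flag values:

K_BACKWARD, K_LOSS = 0x0001, 0x0100

def is_scale_loss_op(op_role):
    # Exact equality: only an op tagged both backward and loss qualifies.
    return op_role == (K_BACKWARD | K_LOSS)

assert is_scale_loss_op(K_BACKWARD | K_LOSS)
assert not is_scale_loss_op(K_BACKWARD)   # plain backward ops are excluded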
paddle/fluid/framework/op_desc.cc
@@ -20,6 +20,7 @@ limitations under the License. */
 #include <unordered_map>
 #include "glog/logging.h"
 #include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/shape_inference.h"
@@ -249,6 +250,13 @@ void OpDesc::RenameOutput(const std::string &old_name,
     std::replace(output.second.begin(), output.second.end(), old_name,
                  new_name);
   }
+
+  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
+  if (it != attrs_.end()) {
+    auto &op_vars = boost::get<std::vector<std::string>>(it->second);
+    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
+  }
+
   need_update_ = true;
 }
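Since the role-var attribute now carries variable names, RenameOutput must keep it in sync when a gradient variable is renamed. A rough Python analogue of the added lines; the attribute key and the variable names are illustrative only:

def rename_in_role_var(attrs, old_name, new_name, key="op_role_var"):
    # Mirrors the std::replace over the attribute's string vector.
    if key in attrs:
        attrs[key] = [new_name if v == old_name else v for v in attrs[key]]

attrs = {"op_role_var": ["fc_0.w_0", "fc_0.w_0@GRAD"]}
rename_in_role_var(attrs, "fc_0.w_0@GRAD", "fc_0.w_0@GRAD@RENAME_0")
assert attrs["op_role_var"] == ["fc_0.w_0", "fc_0.w_0@GRAD@RENAME_0"]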
......
paddle/fluid/framework/op_proto_maker.cc
@@ -13,6 +13,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include <string>
+#include <vector>
 
 namespace paddle {
 namespace framework {
@@ -69,8 +70,9 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
        static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
        static_cast<int>(OpRole::kLoss) |
            static_cast<int>(OpRole::kBackward)});
-  AddAttr<std::string>(OpRoleVarAttrName(), "Optimized for variable")
-      .SetDefault("");
+  AddAttr<std::vector<std::string>>(OpRoleVarAttrName(),
+                                    "Optimized for variable")
+      .SetDefault({});
   Validate();
 }
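The attribute itself changes type from a single string to a vector of strings with an empty-list default, so one op can be associated with several optimized variables at once. A toy illustration (names made up) of why the scalar form lost information:

# Old behaviour: each assignment overwrote the previous parameter name.
attr = ""
for p in ["fc_0.w_0", "fc_0.b_0"]:
    attr = p
assert attr == "fc_0.b_0"            # fc_0.w_0 is lost

# New behaviour: the list keeps every associated parameter.
attr = []
for p in ["fc_0.w_0", "fc_0.b_0"]:
    attr.append(p)
assert attr == ["fc_0.w_0", "fc_0.b_0"]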
......
python/paddle/fluid/backward.py
@@ -534,7 +534,10 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
             if g.op is None:
                 raise ValueError("Unexpected branch")
-            g.op.set_attr(op_role_var_attr_name, p.name)
+            attr_val = [p.name]
+            if g.op.has_attr(op_role_var_attr_name):
+                attr_val.extend(g.op.attr(op_role_var_attr_name))
+            g.op.set_attr(op_role_var_attr_name, attr_val)
 
     return params_and_grads
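append_backward accordingly accumulates parameter names on a shared grad op instead of overwriting: when one backward op produces gradients for several parameters, each pass prepends its parameter to whatever was already recorded. A standalone sketch of that merge, with op attributes modeled as a plain dict:

def tag_role_var(op_attrs, param_name, key="op_role_var"):
    attr_val = [param_name]
    if key in op_attrs:              # mirrors g.op.has_attr(...)
        attr_val.extend(op_attrs[key])
    op_attrs[key] = attr_val

attrs = {}
tag_role_var(attrs, "fc_0.w_0")
tag_role_var(attrs, "fc_0.b_0")
assert attrs["op_role_var"] == ["fc_0.b_0", "fc_0.w_0"]  # later params first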
......
python/paddle/fluid/framework.py
@@ -410,10 +410,14 @@ class Operator(object):
             if op_maker.kOpRoleAttrName() not in self.attrs:
                 self.attrs[op_maker.kOpRoleAttrName()] = self.block.program.op_role
 
-            if len(self.block.program.op_role_var
-                   ) != 0 and op_maker.kOpRoleVarAttrName() not in self.attrs:
-                self.attrs[op_maker.kOpRoleVarAttrName(
-                )] = self.block.program.op_role_var
+            role_var_name = op_maker.kOpRoleVarAttrName()
+            if len(self.block.program.
+                   op_role_var) != 0 and role_var_name not in self.attrs:
+                self.attrs[role_var_name] = self.block.program.op_role_var
+
+            if role_var_name in self.attrs and len(self.attrs[role_var_name]) == 0:
+                del self.attrs[role_var_name]
 
         if len(self.desc.type()) != 0:
             return
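In the Operator constructor, the program-level op_role_var is copied onto each newly created op unless the caller supplied one, and an empty list is then stripped so ops created outside an optimized_guard do not carry a meaningless attribute. A dict-based sketch of that flow:

def fill_role_var(op_attrs, program_role_var, key="op_role_var"):
    if len(program_role_var) != 0 and key not in op_attrs:
        op_attrs[key] = program_role_var
    if key in op_attrs and len(op_attrs[key]) == 0:
        del op_attrs[key]            # never serialize an empty role-var list

attrs = {}
fill_role_var(attrs, ["fc_0.w_0"])
assert attrs == {"op_role_var": ["fc_0.w_0"]}

attrs = {"op_role_var": []}          # explicitly passed but empty
fill_role_var(attrs, [])
assert attrs == {}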
@@ -497,7 +501,6 @@ class Operator(object):
                         attr_name, self.attrs[attr_name].serialize_to_string())
                 else:
                     self.desc.set_attr(attr_name, self.attrs[attr_name])
-
         self.desc.check_attrs()
         no_kernel_op_set = {
             'feed', 'fetch', 'save', 'load', 'recurrent', 'go',
@@ -1020,7 +1023,7 @@ class Program(object):
         self.current_block_idx = 0
         self._seed = 0
         self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
-        self._op_role_var = ""
+        self._op_role_var = []
 
     @property
     def op_role(self):
@@ -1036,15 +1039,15 @@ class Program(object):
     @op_role_var.setter
     def set_op_role_var(self, var_name):
-        self._op_role_var = var_name
+        self._op_role_var = [var_name]
 
     @contextlib.contextmanager
     def optimized_guard(self, var):
         OpRole = core.op_proto_and_checker_maker.OpRole
         self._current_role = OpRole.Optimize
-        self._op_role_var = var.name if isinstance(var, Variable) else var
+        self._op_role_var = [var.name if isinstance(var, Variable) else var]
         yield
-        self._op_role_var = []
+        self._op_role_var = []
         self._current_role = OpRole.Forward
 
     def __str__(self):
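Program._op_role_var is now a list throughout, and optimized_guard wraps optimizer-op creation so those ops are tagged with the parameter being updated. A self-contained imitation of the guard pattern; ProgramSketch and the string argument are stand-ins for fluid's Program and Variable:

import contextlib

class ProgramSketch(object):
    def __init__(self):
        self._op_role_var = []       # list, matching the new default

    @contextlib.contextmanager
    def optimized_guard(self, var):
        # Accept either a Variable-like object with .name or a raw string.
        self._op_role_var = [getattr(var, "name", var)]
        yield
        self._op_role_var = []       # reset once the optimizer op is built

prog = ProgramSketch()
with prog.optimized_guard("fc_0.w_0"):
    assert prog._op_role_var == ["fc_0.w_0"]
assert prog._op_role_var == []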
......