提交 44c52a8c 编写于 作者: Y yuyang18

Polish op_proto_maker

上级 577c19b2
无相关合并请求
......@@ -163,8 +163,13 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
for (auto &og : op->OutputArgumentNames()) {
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
if (static_cast<bool>(boost::get<int>(op->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward))) {
auto &backward_vars = boost::get<std::vector<std::string>>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
for (auto &og : backward_vars) {
if (balance_parameter_opt_between_cards_) {
CreateReduceOp(&result, og, cur_device_id);
var_name_on_devices[cur_device_id].emplace(og);
......@@ -399,11 +404,11 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
}
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
// FIXME(yy): Do not hard code like this
return op.OutputArgumentNames().size() == 1 &&
op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);
return boost::get<int>(
op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
(static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss));
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -20,6 +20,7 @@ limitations under the License. */
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/shape_inference.h"
......@@ -249,6 +250,13 @@ void OpDesc::RenameOutput(const std::string &old_name,
std::replace(output.second.begin(), output.second.end(), old_name,
new_name);
}
auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
if (it != attrs_.end()) {
auto &op_vars = boost::get<std::vector<std::string>>(it->second);
std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
}
need_update_ = true;
}
......
......@@ -13,6 +13,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_proto_maker.h"
#include <string>
#include <vector>
namespace paddle {
namespace framework {
......@@ -69,8 +70,9 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kLoss) |
static_cast<int>(OpRole::kBackward)});
AddAttr<std::string>(OpRoleVarAttrName(), "Optimized for variable")
.SetDefault("");
AddAttr<std::vector<std::string>>(OpRoleVarAttrName(),
"Optimized for variable")
.SetDefault({});
Validate();
}
......
......@@ -534,7 +534,10 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
if g.op is None:
raise ValueError("Unexpected branch")
g.op.set_attr(op_role_var_attr_name, p.name)
attr_val = [p.name]
if g.op.has_attr(op_role_var_attr_name):
attr_val.extend(g.op.attr(op_role_var_attr_name))
g.op.set_attr(op_role_var_attr_name, attr_val)
return params_and_grads
......
......@@ -410,10 +410,14 @@ class Operator(object):
if op_maker.kOpRoleAttrName() not in self.attrs:
self.attrs[op_maker.kOpRoleAttrName()] = self.block.program.op_role
if len(self.block.program.op_role_var
) != 0 and op_maker.kOpRoleVarAttrName() not in self.attrs:
self.attrs[op_maker.kOpRoleVarAttrName(
)] = self.block.program.op_role_var
role_var_name = op_maker.kOpRoleVarAttrName()
if len(self.block.program.
op_role_var) != 0 and role_var_name not in self.attrs:
self.attrs[role_var_name] = self.block.program.op_role_var
if role_var_name in self.attrs and len(self.attrs[role_var_name]) == 0:
del self.attrs[role_var_name]
if len(self.desc.type()) != 0:
return
......@@ -497,7 +501,6 @@ class Operator(object):
attr_name, self.attrs[attr_name].serialize_to_string())
else:
self.desc.set_attr(attr_name, self.attrs[attr_name])
self.desc.check_attrs()
no_kernel_op_set = {
'feed', 'fetch', 'save', 'load', 'recurrent', 'go',
......@@ -1020,7 +1023,7 @@ class Program(object):
self.current_block_idx = 0
self._seed = 0
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self._op_role_var = ""
self._op_role_var = []
@property
def op_role(self):
......@@ -1036,15 +1039,15 @@ class Program(object):
@op_role_var.setter
def set_op_role_var(self, var_name):
self._op_role_var = var_name
self._op_role_var = [var_name]
@contextlib.contextmanager
def optimized_guard(self, var):
OpRole = core.op_proto_and_checker_maker.OpRole
self._current_role = OpRole.Optimize
self._op_role_var = var.name if isinstance(var, Variable) else var
self._op_role_var = [var.name if isinstance(var, Variable) else var]
yield
self._op_role_var = ""
self._op_role_var = []
self._current_role = OpRole.Forward
def __str__(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部