Commit 017bba16 authored by yuyang18

Add op role

Parent dfdcb7ea
......@@ -96,10 +96,7 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
info->proto_ = new proto::OpProto;
info->checker_ = new OpAttrChecker();
T maker;
maker.SetProto(info->proto_);
maker.SetChecker(info->checker_);
maker.Make();
maker.Validate();
maker(info->proto_, info->checker_);
info->proto_->set_type(op_type);
PADDLE_ENFORCE(
info->proto_->IsInitialized(),
......
......@@ -55,5 +55,25 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
}
}
void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
OpAttrChecker* attr_checker) {
proto_ = proto;
op_checker_ = attr_checker;
Make();
AddAttr<int>(OpRoleAttrName(), "The role of this operator")
.InEnum(
{static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kLoss) |
static_cast<int>(OpRole::kBackward)});
AddAttr<std::string>(OpRoleVarAttrName(), "Optimized for variable")
.SetDefault("");
Validate();
}
} // namespace framework
} // namespace paddle
......@@ -20,21 +20,28 @@ limitations under the License. */
namespace paddle {
namespace framework {
enum class OpRole {
kForward = 0x0000,
kBackward = 0x0001,
kOptimize = 0x0002,
kLoss = 0x0100,
};
// This class not only builds the proto but also initializes the attribute checkers.
class OpProtoAndCheckerMaker {
public:
static const char *OpRoleAttrName() { return "op_role"; }
static const char *OpRoleVarAttrName() { return "op_role_var"; }
void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);
virtual void Make() = 0;
virtual ~OpProtoAndCheckerMaker() {
CHECK(validated_) << "should call Validate after build";
}
void SetProto(proto::OpProto *proto) { proto_ = proto; }
void SetChecker(OpAttrChecker *attr_checker) { op_checker_ = attr_checker; }
void Validate();
protected:
struct VariableBuilder {
proto::OpProto::Var *var_;
......@@ -76,6 +83,7 @@ class OpProtoAndCheckerMaker {
private:
void CheckNoDuplicatedInOutAttrs();
void Validate();
proto::OpProto *proto_;
OpAttrChecker *op_checker_;
......
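Note: the `OpRole` values declared above are bit flags rather than a plain enumeration: `kLoss` (0x0100) sits in its own bit so it can be combined with `kForward` or `kBackward`, which is why the `InEnum` list in the maker also accepts the OR'd combinations. A minimal Python sketch of the arithmetic (constants copied from the enum; the variable names are illustrative only):

```python
# Bit-flag roles, mirroring the C++ OpRole enum above.
FORWARD, BACKWARD, OPTIMIZE, LOSS = 0x0000, 0x0001, 0x0002, 0x0100

loss_forward = LOSS | FORWARD    # 0x100: the forward loss operator
loss_backward = LOSS | BACKWARD  # 0x101: the gradient op generated for the loss

# A downstream pass can test each bit independently.
assert loss_backward & LOSS and loss_backward & BACKWARD
print(hex(loss_forward), hex(loss_backward))  # 0x100 0x101
```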
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/const_value.h"
#include <paddle/fluid/framework/op_proto_maker.h>
#include "paddle/fluid/framework/operator.h"
namespace paddle {
......@@ -23,6 +24,21 @@ void BindConstValue(pybind11::module* m) {
m->def("kTempVarName", [] { return framework::kTempVarName; });
m->def("kGradVarSuffix", [] { return framework::kGradVarSuffix; });
m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; });
auto op_proto_and_checker_maker =
m->def_submodule("op_proto_and_checker_maker");
pybind11::enum_<framework::OpRole>(op_proto_and_checker_maker, "OpRole")
.value("Forward", framework::OpRole::kForward)
.value("Backward", framework::OpRole::kBackward)
.value("Optimize", framework::OpRole::kOptimize)
.value("Loss", framework::OpRole::kLoss);
op_proto_and_checker_maker.def(
"kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
op_proto_and_checker_maker.def(
"kOpRoleVarAttrName",
framework::OpProtoAndCheckerMaker::OpRoleVarAttrName);
}
} // namespace pybind
......
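Note: the pybind changes above expose the enum and the two attribute-name getters as a `core.op_proto_and_checker_maker` submodule, which is what the Python-side changes below build on. A hedged usage sketch, assuming a Paddle Fluid build that contains this commit:

```python
from paddle.fluid import core

maker = core.op_proto_and_checker_maker
print(maker.kOpRoleAttrName())      # "op_role"
print(maker.kOpRoleVarAttrName())   # "op_role_var"
print(int(maker.OpRole.Loss))       # 256 (0x0100), matching the C++ enum
```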
......@@ -51,6 +51,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
op_desc.set_input(para, args)
for para, args in outputs.iteritems():
op_desc.set_output(para, args)
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
if op_role_attr_name not in attrs:
attrs[
op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward
for name, val in attrs.iteritems():
if isinstance(val, framework.Block):
op_desc.set_block_attr(name, val.desc)
......@@ -335,9 +341,12 @@ def _append_backward_ops_(block,
no_grad_dict[block.idx])
# append op_desc in grad_op_descs to target_block
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
for op_desc in grad_op_descs:
new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op_desc)
new_op_desc.set_attr(op_role_attr_name, backward)
grad_to_var["__current_op_desc__"] = new_op_desc
if callbacks is not None:
assert (isinstance(callbacks, list))
......@@ -439,6 +448,11 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
(list[(Variable,Variable)]): list of (parameter, gradient) pair.
"""
assert isinstance(loss, framework.Variable)
loss.op.set_attr(core.op_proto_and_checker_maker.kOpRoleAttrName(),
int(core.op_proto_and_checker_maker.OpRole.Forward) |
int(core.op_proto_and_checker_maker.OpRole.Loss))
if callbacks is not None:
assert isinstance(callbacks, list)
......@@ -456,12 +470,16 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
current_block_idx = program.current_block_idx
grad_to_var = dict()
op_desc = _create_op_desc_("fill_constant", {}, {
"Out": [_append_grad_suffix_(loss.name)]
}, {"shape": [1],
op_desc = _create_op_desc_(
"fill_constant", {}, {"Out": [_append_grad_suffix_(loss.name)]}, {
"shape": [1],
"value": 1.0,
"dtype": loss.dtype,
"force_cpu": False})
"force_cpu": False,
core.op_proto_and_checker_maker.kOpRoleAttrName():
int(core.op_proto_and_checker_maker.OpRole.Backward) |
int(core.op_proto_and_checker_maker.OpRole.Loss),
})
root_block.desc.append_op().copy_from(op_desc)
block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
......@@ -503,6 +521,21 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
params_and_grads.append((param_var, grad_var))
else:
params_and_grads.append((param_var, None))
op_role_var_attr_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
for p, g in params_and_grads:
if g is None:
continue
for op in reversed(program.global_block().ops):
assert isinstance(op, framework.Operator)
if g.name in op.output_arg_names:
g.op = op
break
if g.op is None:
raise ValueError("Unexpected branch")
g.op.set_attr(op_role_var_attr_name, p.name)
return params_and_grads
......
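Note: taken together, the backward.py changes tag three kinds of operators: the existing forward loss op becomes Forward | Loss, the fill_constant that seeds the loss gradient becomes Backward | Loss, and every appended gradient op becomes plain Backward (with `op_role_var` later pointing each gradient op back at its parameter). A rough sketch of that bookkeeping using plain dicts in place of real OpDesc objects (all names here are illustrative, not Paddle API):

```python
FORWARD, BACKWARD, LOSS = 0x0000, 0x0001, 0x0100

def tag_backward_pass(loss_op, grad_op_descs):
    # Re-tag the forward loss op as Forward | Loss.
    loss_op["op_role"] = FORWARD | LOSS
    # The generated fill_constant that seeds d(loss) is Backward | Loss.
    fill_op = {"type": "fill_constant", "op_role": BACKWARD | LOSS}
    # Every copied gradient op is plain Backward.
    for op in grad_op_descs:
        op["op_role"] = BACKWARD
    return [fill_op] + grad_op_descs

loss_op = {"type": "mean"}
new_ops = tag_backward_pass(loss_op, [{"type": "mean_grad"}])
print(hex(loss_op["op_role"]))                 # 0x100
print([hex(op["op_role"]) for op in new_ops])  # ['0x101', '0x1']
```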
......@@ -214,21 +214,24 @@ def set_gradient_clip(clip, param_list=None, program=None):
def append_gradient_clip_ops(param_grad):
context = dict()
create_op_callbacks = []
for p, g in param_grad:
with p.block.program.optimized_guard(p):
clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr())
if clip_attr is None:
clip_attr = NullGradientClipAttr()
if not isinstance(clip_attr, BaseGradientClipAttr):
raise TypeError(
"clip attribute should be an instance of BaseGradientClipAttr")
"clip attribute should be an instance of BaseGradientClipAttr"
)
clip_attr.process_context(context=context, param=p, grad=g)
create_op_callbacks.append(
functools.partial(
clip_attr.create_operators, param=p, grad=g))
return [each_callback() for each_callback in create_op_callbacks]
res = []
for p, g in param_grad:
with p.block.program.optimized_guard(p):
res.append(clip_attr.create_operators(param=p, grad=g))
return res
ClipByValue = GradientClipByValue
......
......@@ -402,6 +402,19 @@ class Operator(object):
self.block = block
self.desc = desc
self.attrs = attrs
if self.attrs is None:
self.attrs = dict()
del attrs
op_maker = core.op_proto_and_checker_maker
if op_maker.kOpRoleAttrName() not in self.attrs:
self.attrs[op_maker.kOpRoleAttrName()] = self.block.program.op_role
if len(self.block.program.op_role_var
) != 0 and op_maker.kOpRoleVarAttrName() not in self.attrs:
self.attrs[op_maker.kOpRoleVarAttrName(
)] = self.block.program.op_role_var
if len(self.desc.type()) != 0:
return
if type is None:
......@@ -467,21 +480,23 @@ class Operator(object):
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
if attrs is not None:
if not isinstance(attrs, dict):
if self.attrs is not None:
if not isinstance(self.attrs, dict):
raise TypeError("'attrs' should be a dict.")
for attr in proto.attrs:
attr_name = attr.name
if (attr_name not in attrs) or (attrs[attr_name] is None):
if (attr_name not in self.attrs) or (
self.attrs[attr_name] is None):
continue
if isinstance(attrs[attr_name], Block):
self.desc.set_block_attr(attr_name, attrs[attr_name].desc)
elif isinstance(attrs[attr_name], core.BlockDesc) or \
isinstance(attrs[attr_name], core.ProgramDesc):
if isinstance(self.attrs[attr_name], Block):
self.desc.set_block_attr(attr_name,
self.attrs[attr_name].desc)
elif isinstance(self.attrs[attr_name], core.BlockDesc) or \
isinstance(self.attrs[attr_name], core.ProgramDesc):
self.desc.set_serialized_attr(
attr_name, attrs[attr_name].serialize_to_string())
attr_name, self.attrs[attr_name].serialize_to_string())
else:
self.desc.set_attr(attr_name, attrs[attr_name])
self.desc.set_attr(attr_name, self.attrs[attr_name])
self.desc.check_attrs()
no_kernel_op_set = {
......@@ -610,6 +625,10 @@ class Operator(object):
"""
return self.desc.attr_type(name)
def set_attr(self, name, val):
self.attrs[name] = val
self.desc.set_attr(name, val)
@property
def attr_names(self):
"""
......@@ -1000,6 +1019,33 @@ class Program(object):
self.blocks = [Block(self, 0)]
self.current_block_idx = 0
self._seed = 0
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self._op_role_var = ""
@property
def op_role(self):
return self._current_role
@op_role.setter
def op_role(self, role):
self._current_role = role
@property
def op_role_var(self):
return self._op_role_var
@op_role_var.setter
def op_role_var(self, var_name):
self._op_role_var = var_name
@contextlib.contextmanager
def optimized_guard(self, var):
OpRole = core.op_proto_and_checker_maker.OpRole
self._current_role = OpRole.Optimize
self._op_role_var = var.name if isinstance(var, Variable) else var
yield
self._op_role_var = ""
self._current_role = OpRole.Forward
def __str__(self):
return self.to_string(True)
......
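Note: `optimized_guard` is the hook the clip, optimizer, and regularizer changes below rely on: inside the `with` block, every `Operator` constructed through the Python API defaults to `OpRole.Optimize` and records the guarded variable's name in `op_role_var` (see the `Operator.__init__` hunk above), and the state is restored on exit. A hedged usage sketch, assuming a Fluid build that includes this commit; the parameter name is made up:

```python
import paddle.fluid as fluid
from paddle.fluid import core

OpRole = core.op_proto_and_checker_maker.OpRole
prog = fluid.Program()

# A plain string works too: optimized_guard only stores var.name (or the string).
with prog.optimized_guard("fc_0.w_0"):
    assert prog.op_role == OpRole.Optimize
    assert prog.op_role_var == "fc_0.w_0"

# On exit the program falls back to tagging new ops as Forward.
assert prog.op_role == OpRole.Forward
```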
......@@ -213,6 +213,8 @@ class Optimizer(object):
optimize_ops = []
for param_and_grad in parameters_and_grads:
with param_and_grad[0].block.program.optimized_guard(
param_and_grad[0]):
if param_and_grad[0].trainable is True and param_and_grad[
1] is not None:
optimize_op = self._append_optimize_op(loss.block,
......
......@@ -43,6 +43,7 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
"""
params_and_grads = []
for param, grad in parameters_and_grads:
with param.block.program.optimized_guard(param):
# If no gradient then we don't need to do anything
if grad is None:
params_and_grads.append((param, grad))
......