提交 017bba16 编写于 作者: Y yuyang18

Add op role

上级 dfdcb7ea
...@@ -96,10 +96,7 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> { ...@@ -96,10 +96,7 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
info->proto_ = new proto::OpProto; info->proto_ = new proto::OpProto;
info->checker_ = new OpAttrChecker(); info->checker_ = new OpAttrChecker();
T maker; T maker;
maker.SetProto(info->proto_); maker(info->proto_, info->checker_);
maker.SetChecker(info->checker_);
maker.Make();
maker.Validate();
info->proto_->set_type(op_type); info->proto_->set_type(op_type);
PADDLE_ENFORCE( PADDLE_ENFORCE(
info->proto_->IsInitialized(), info->proto_->IsInitialized(),
......
...@@ -55,5 +55,25 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { ...@@ -55,5 +55,25 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
} }
} }
void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
OpAttrChecker* attr_checker) {
proto_ = proto;
op_checker_ = attr_checker;
Make();
AddAttr<int>(OpRoleAttrName(), "The role of this operator")
.InEnum(
{static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kLoss) |
static_cast<int>(OpRole::kBackward)});
AddAttr<std::string>(OpRoleVarAttrName(), "Optimized for variable")
.SetDefault("");
Validate();
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -20,21 +20,28 @@ limitations under the License. */ ...@@ -20,21 +20,28 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
enum class OpRole {
kForward = 0x0000,
kBackward = 0x0001,
kOptimize = 0x0002,
kLoss = 0x0100,
};
// this class not only make proto but also init attribute checkers. // this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker { class OpProtoAndCheckerMaker {
public: public:
static const char *OpRoleAttrName() { return "op_role"; }
static const char *OpRoleVarAttrName() { return "op_role_var"; }
void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);
virtual void Make() = 0; virtual void Make() = 0;
virtual ~OpProtoAndCheckerMaker() { virtual ~OpProtoAndCheckerMaker() {
CHECK(validated_) << "should call Validate after build"; CHECK(validated_) << "should call Validate after build";
} }
void SetProto(proto::OpProto *proto) { proto_ = proto; }
void SetChecker(OpAttrChecker *attr_checker) { op_checker_ = attr_checker; }
void Validate();
protected: protected:
struct VariableBuilder { struct VariableBuilder {
proto::OpProto::Var *var_; proto::OpProto::Var *var_;
...@@ -76,6 +83,7 @@ class OpProtoAndCheckerMaker { ...@@ -76,6 +83,7 @@ class OpProtoAndCheckerMaker {
private: private:
void CheckNoDuplicatedInOutAttrs(); void CheckNoDuplicatedInOutAttrs();
void Validate();
proto::OpProto *proto_; proto::OpProto *proto_;
OpAttrChecker *op_checker_; OpAttrChecker *op_checker_;
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/const_value.h"
#include <paddle/fluid/framework/op_proto_maker.h>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
namespace paddle { namespace paddle {
...@@ -23,6 +24,21 @@ void BindConstValue(pybind11::module* m) { ...@@ -23,6 +24,21 @@ void BindConstValue(pybind11::module* m) {
m->def("kTempVarName", [] { return framework::kTempVarName; }); m->def("kTempVarName", [] { return framework::kTempVarName; });
m->def("kGradVarSuffix", [] { return framework::kGradVarSuffix; }); m->def("kGradVarSuffix", [] { return framework::kGradVarSuffix; });
m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; }); m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; });
auto op_proto_and_checker_maker =
m->def_submodule("op_proto_and_checker_maker");
pybind11::enum_<framework::OpRole>(op_proto_and_checker_maker, "OpRole")
.value("Forward", framework::OpRole::kForward)
.value("Backward", framework::OpRole::kBackward)
.value("Optimize", framework::OpRole::kOptimize)
.value("Loss", framework::OpRole::kLoss);
op_proto_and_checker_maker.def(
"kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
op_proto_and_checker_maker.def(
"kOpRoleVarAttrName",
framework::OpProtoAndCheckerMaker::OpRoleVarAttrName);
} }
} // namespace pybind } // namespace pybind
......
...@@ -51,6 +51,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): ...@@ -51,6 +51,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
op_desc.set_input(para, args) op_desc.set_input(para, args)
for para, args in outputs.iteritems(): for para, args in outputs.iteritems():
op_desc.set_output(para, args) op_desc.set_output(para, args)
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
if op_role_attr_name not in attrs:
attrs[
op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward
for name, val in attrs.iteritems(): for name, val in attrs.iteritems():
if isinstance(val, framework.Block): if isinstance(val, framework.Block):
op_desc.set_block_attr(name, val.desc) op_desc.set_block_attr(name, val.desc)
...@@ -335,9 +341,12 @@ def _append_backward_ops_(block, ...@@ -335,9 +341,12 @@ def _append_backward_ops_(block,
no_grad_dict[block.idx]) no_grad_dict[block.idx])
# append op_desc in grad_op_descs to target_block # append op_desc in grad_op_descs to target_block
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
for op_desc in grad_op_descs: for op_desc in grad_op_descs:
new_op_desc = target_block.desc.append_op() new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op_desc) new_op_desc.copy_from(op_desc)
new_op_desc.set_attr(op_role_attr_name, backward)
grad_to_var["__current_op_desc__"] = new_op_desc grad_to_var["__current_op_desc__"] = new_op_desc
if callbacks is not None: if callbacks is not None:
assert (isinstance(callbacks, list)) assert (isinstance(callbacks, list))
...@@ -439,6 +448,11 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, ...@@ -439,6 +448,11 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
(list[(Variable,Variable)]): list of (parameter, gradient) pair. (list[(Variable,Variable)]): list of (parameter, gradient) pair.
""" """
assert isinstance(loss, framework.Variable) assert isinstance(loss, framework.Variable)
loss.op.set_attr(core.op_proto_and_checker_maker.kOpRoleAttrName(),
int(core.op_proto_and_checker_maker.OpRole.Forward) |
int(core.op_proto_and_checker_maker.OpRole.Loss))
if callbacks is not None: if callbacks is not None:
isinstance(callbacks, list) isinstance(callbacks, list)
...@@ -456,12 +470,16 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, ...@@ -456,12 +470,16 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
current_block_idx = program.current_block_idx current_block_idx = program.current_block_idx
grad_to_var = dict() grad_to_var = dict()
op_desc = _create_op_desc_("fill_constant", {}, { op_desc = _create_op_desc_(
"Out": [_append_grad_suffix_(loss.name)] "fill_constant", {}, {"Out": [_append_grad_suffix_(loss.name)]}, {
}, {"shape": [1], "shape": [1],
"value": 1.0, "value": 1.0,
"dtype": loss.dtype, "dtype": loss.dtype,
"force_cpu": False}) "force_cpu": False,
core.op_proto_and_checker_maker.kOpRoleAttrName():
int(core.op_proto_and_checker_maker.OpRole.Backward) |
int(core.op_proto_and_checker_maker.OpRole.Loss),
})
root_block.desc.append_op().copy_from(op_desc) root_block.desc.append_op().copy_from(op_desc)
block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0])) block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
...@@ -503,6 +521,21 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, ...@@ -503,6 +521,21 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
params_and_grads.append((param_var, grad_var)) params_and_grads.append((param_var, grad_var))
else: else:
params_and_grads.append((param_var, None)) params_and_grads.append((param_var, None))
op_role_var_attr_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
for p, g in params_and_grads:
if g is None:
continue
for op in reversed(program.global_block().ops):
assert isinstance(op, framework.Operator)
if g.name in op.output_arg_names:
g.op = op
break
if g.op is None:
raise ValueError("Unexpected branch")
g.op.set_attr(op_role_var_attr_name, p.name)
return params_and_grads return params_and_grads
......
...@@ -214,21 +214,24 @@ def set_gradient_clip(clip, param_list=None, program=None): ...@@ -214,21 +214,24 @@ def set_gradient_clip(clip, param_list=None, program=None):
def append_gradient_clip_ops(param_grad): def append_gradient_clip_ops(param_grad):
context = dict() context = dict()
create_op_callbacks = []
for p, g in param_grad: for p, g in param_grad:
with p.block.program.optimized_guard(p):
clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr()) clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr())
if clip_attr is None: if clip_attr is None:
clip_attr = NullGradientClipAttr() clip_attr = NullGradientClipAttr()
if not isinstance(clip_attr, BaseGradientClipAttr): if not isinstance(clip_attr, BaseGradientClipAttr):
raise TypeError( raise TypeError(
"clip attribute should be an instance of BaseGradientClipAttr") "clip attribute should be an instance of BaseGradientClipAttr"
)
clip_attr.process_context(context=context, param=p, grad=g) clip_attr.process_context(context=context, param=p, grad=g)
create_op_callbacks.append(
functools.partial(
clip_attr.create_operators, param=p, grad=g))
return [each_callback() for each_callback in create_op_callbacks] res = []
for p, g in param_grad:
with p.block.program.optimized_guard(p):
res.append(clip_attr.create_operators(param=p, grad=g))
return res
ClipByValue = GradientClipByValue ClipByValue = GradientClipByValue
......
...@@ -402,6 +402,19 @@ class Operator(object): ...@@ -402,6 +402,19 @@ class Operator(object):
self.block = block self.block = block
self.desc = desc self.desc = desc
self.attrs = attrs self.attrs = attrs
if self.attrs is None:
self.attrs = dict()
del attrs
op_maker = core.op_proto_and_checker_maker
if op_maker.kOpRoleAttrName() not in self.attrs:
self.attrs[op_maker.kOpRoleAttrName()] = self.block.program.op_role
if len(self.block.program.op_role_var
) != 0 and op_maker.kOpRoleVarAttrName() not in self.attrs:
self.attrs[op_maker.kOpRoleVarAttrName(
)] = self.block.program.op_role_var
if len(self.desc.type()) != 0: if len(self.desc.type()) != 0:
return return
if type is None: if type is None:
...@@ -467,21 +480,23 @@ class Operator(object): ...@@ -467,21 +480,23 @@ class Operator(object):
arg.op = self arg.op = self
self.desc.set_output(out_proto.name, out_arg_names) self.desc.set_output(out_proto.name, out_arg_names)
if attrs is not None: if self.attrs is not None:
if not isinstance(attrs, dict): if not isinstance(self.attrs, dict):
raise TypeError("'attrs' should be a dict.") raise TypeError("'attrs' should be a dict.")
for attr in proto.attrs: for attr in proto.attrs:
attr_name = attr.name attr_name = attr.name
if (attr_name not in attrs) or (attrs[attr_name] is None): if (attr_name not in self.attrs) or (
self.attrs[attr_name] is None):
continue continue
if isinstance(attrs[attr_name], Block): if isinstance(self.attrs[attr_name], Block):
self.desc.set_block_attr(attr_name, attrs[attr_name].desc) self.desc.set_block_attr(attr_name,
elif isinstance(attrs[attr_name], core.BlockDesc) or \ self.attrs[attr_name].desc)
isinstance(attrs[attr_name], core.ProgramDesc): elif isinstance(self.attrs[attr_name], core.BlockDesc) or \
isinstance(self.attrs[attr_name], core.ProgramDesc):
self.desc.set_serialized_attr( self.desc.set_serialized_attr(
attr_name, attrs[attr_name].serialize_to_string()) attr_name, self.attrs[attr_name].serialize_to_string())
else: else:
self.desc.set_attr(attr_name, attrs[attr_name]) self.desc.set_attr(attr_name, self.attrs[attr_name])
self.desc.check_attrs() self.desc.check_attrs()
no_kernel_op_set = { no_kernel_op_set = {
...@@ -610,6 +625,10 @@ class Operator(object): ...@@ -610,6 +625,10 @@ class Operator(object):
""" """
return self.desc.attr_type(name) return self.desc.attr_type(name)
def set_attr(self, name, val):
self.attrs[name] = val
self.desc.set_attr(name, val)
@property @property
def attr_names(self): def attr_names(self):
""" """
...@@ -1000,6 +1019,33 @@ class Program(object): ...@@ -1000,6 +1019,33 @@ class Program(object):
self.blocks = [Block(self, 0)] self.blocks = [Block(self, 0)]
self.current_block_idx = 0 self.current_block_idx = 0
self._seed = 0 self._seed = 0
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self._op_role_var = ""
@property
def op_role(self):
return self._current_role
@op_role.setter
def set_op_role(self, role):
self._current_role = role
@property
def op_role_var(self):
return self._op_role_var
@op_role_var.setter
def set_op_role_var(self, var_name):
self._op_role_var = var_name
@contextlib.contextmanager
def optimized_guard(self, var):
OpRole = core.op_proto_and_checker_maker.OpRole
self._current_role = OpRole.Optimize
self._op_role_var = var.name if isinstance(var, Variable) else var
yield
self._op_role_var = ""
self._current_role = OpRole.Forward
def __str__(self): def __str__(self):
return self.to_string(True) return self.to_string(True)
......
...@@ -213,6 +213,8 @@ class Optimizer(object): ...@@ -213,6 +213,8 @@ class Optimizer(object):
optimize_ops = [] optimize_ops = []
for param_and_grad in parameters_and_grads: for param_and_grad in parameters_and_grads:
with param_and_grad[0].block.program.optimized_guard(
param_and_grad[0]):
if param_and_grad[0].trainable is True and param_and_grad[ if param_and_grad[0].trainable is True and param_and_grad[
1] is not None: 1] is not None:
optimize_op = self._append_optimize_op(loss.block, optimize_op = self._append_optimize_op(loss.block,
......
...@@ -43,6 +43,7 @@ def append_regularization_ops(parameters_and_grads, regularization=None): ...@@ -43,6 +43,7 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
""" """
params_and_grads = [] params_and_grads = []
for param, grad in parameters_and_grads: for param, grad in parameters_and_grads:
with param.block.program.optimized_guard(param):
# If no gradient then we don't need to do anything # If no gradient then we don't need to do anything
if grad is None: if grad is None:
params_and_grads.append((param, grad)) params_and_grads.append((param, grad))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册