Unverified · Commit 8653cf30 authored by Yu Yang, committed by GitHub

Merge pull request #10656 from reyoung/feature/support_op_role

Add `op_role` into OpDesc.
@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA
@@ -159,25 +160,39 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
      if (!is_forwarding && places_.size() > 1) {
        // Currently, we assume that once gradient is generated, it can be
        // broadcast, and each gradient is only broadcast once.
        if (static_cast<bool>(boost::get<int>(op->GetAttr(
                                  OpProtoAndCheckerMaker::OpRoleAttrName())) &
                              static_cast<int>(OpRole::kBackward))) {
          try {
            auto backward_vars =
                boost::get<std::vector<std::string>>(op->GetNullableAttr(
                    OpProtoAndCheckerMaker::OpRoleVarAttrName()));

            PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);

            for (size_t i = 0; i < backward_vars.size(); i += 2) {
              auto &p_name = backward_vars[i];
              auto &g_name = backward_vars[i + 1];
              VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;

              switch (strategy_.reduce_) {
                case BuildStrategy::ReduceStrategy::kReduce:
                  CreateReduceOp(&result, g_name, cur_device_id);
                  var_name_on_devices[cur_device_id].emplace(g_name);
                  bcast_var_name_set[cur_device_id].emplace(p_name);
                  cur_device_id = (cur_device_id + 1) % places_.size();
                  break;
                case BuildStrategy::ReduceStrategy::kAllReduce:
                  if (IsSparseGradient(var_types, g_name)) {
                    CreateReduceOp(&result, g_name, 0);
                    CreateBroadcastOp(&result, g_name, 0);
                  } else {
                    InsertNCCLAllReduceOp(&result, g_name);
                  }
                  break;
              }
            }
          } catch (boost::bad_get e) {
          }
        }
      }
@@ -398,11 +413,12 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
}

bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
  return boost::get<int>(
             op.GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
             (static_cast<int>(OpRole::kBackward) |
              static_cast<int>(OpRole::kLoss)) &&
         !loss_var_name_.empty();  // If loss_var is empty, this is test mode.
}
}  // namespace details
}  // namespace framework
}  // namespace paddle
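The builder now relies on the op_role and op_role_var attributes instead of inspecting gradient output names: op_role_var is a flat list that stores parameter and gradient names as consecutive pairs, which is why the loop above walks it two entries at a time. Below is a minimal Python sketch of the same decoding; the helper name and sample variable names are made up for illustration and are not part of the patch.

# Sketch only: op_role_var is a flat list [p0, g0, p1, g1, ...]; the C++ loop
# above consumes it two entries at a time. iter_param_grad_pairs is a
# hypothetical helper, not a framework API.
def iter_param_grad_pairs(op_role_var):
    assert len(op_role_var) % 2 == 0, "expect flattened (param, grad) pairs"
    for i in range(0, len(op_role_var), 2):
        yield op_role_var[i], op_role_var[i + 1]

for p_name, g_name in iter_param_grad_pairs(["fc_0.w_0", "fc_0.w_0@GRAD"]):
    print("Bcast", g_name, "for parameter", p_name)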
@@ -96,10 +96,7 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
    info->proto_ = new proto::OpProto;
    info->checker_ = new OpAttrChecker();
    T maker;
    maker(info->proto_, info->checker_);
    info->proto_->set_type(op_type);
    PADDLE_ENFORCE(
        info->proto_->IsInitialized(),
......
@@ -20,6 +20,7 @@ limitations under the License. */
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/shape_inference.h"
@@ -222,6 +223,15 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
  return it->second;
}

Attribute OpDesc::GetNullableAttr(const std::string &name) const {
  auto it = attrs_.find(name);
  if (it != attrs_.end()) {
    return it->second;
  } else {
    return Attribute();
  }
}

int OpDesc::GetBlockAttr(const std::string &name) const {
  auto it = attrs_.find(name);
  PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
@@ -249,6 +259,13 @@ void OpDesc::RenameOutput(const std::string &old_name,
    std::replace(output.second.begin(), output.second.end(), old_name,
                 new_name);
  }

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
    auto &op_vars = boost::get<std::vector<std::string>>(it->second);
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

  need_update_ = true;
}
......
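The RenameOutput change above keeps op_role_var consistent when a gradient variable is renamed (for example to the @RENAME@ names produced in backward.py below). A rough Python analogue of that bookkeeping follows; the dict-based attrs container is an assumption for illustration, not the framework API.

# Sketch: when a variable is renamed, any occurrence of the old name inside
# the op_role_var attribute must be renamed as well, mirroring the
# std::replace call above.
def rename_in_op_role_var(attrs, old_name, new_name, key="op_role_var"):
    role_vars = attrs.get(key)
    if role_vars is not None:
        attrs[key] = [new_name if v == old_name else v for v in role_vars]
    return attrs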
@@ -78,6 +78,8 @@ class OpDesc {
  Attribute GetAttr(const std::string &name) const;

  Attribute GetNullableAttr(const std::string &name) const;

  int GetBlockAttr(const std::string &name) const;

  void Rename(const std::string &old_name, const std::string &new_name);
......
@@ -13,6 +13,7 @@ limitations under the License. */

#include "paddle/fluid/framework/op_proto_maker.h"
#include <string>
#include <vector>

namespace paddle {
namespace framework {
@@ -55,5 +56,28 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
  }
}

void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
                                        OpAttrChecker* attr_checker) {
  proto_ = proto;
  op_checker_ = attr_checker;
  Make();

  AddAttr<int>(OpRoleAttrName(), "The role of this operator")
      .InEnum(
          {static_cast<int>(OpRole::kForward),
           static_cast<int>(OpRole::kBackward),
           static_cast<int>(OpRole::kOptimize),
           static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
           static_cast<int>(OpRole::kLoss) |
               static_cast<int>(OpRole::kBackward),
           static_cast<int>(OpRole::kNotSpecified)})
      .SetDefault(static_cast<int>(OpRole::kNotSpecified));
  AddAttr<std::vector<std::string>>(OpRoleVarAttrName(),
                                    "Optimized for variable")
      .SetDefault({});

  Validate();
}

}  // namespace framework
}  // namespace paddle
@@ -20,21 +20,31 @@ limitations under the License. */
namespace paddle {
namespace framework {

enum class OpRole {
  kForward = 0x0000,
  kBackward = 0x0001,
  kOptimize = 0x0002,

  kLoss = 0x0100,
  // The default value of an op's role. This should only be used for unittests
  // and for CreateOp inside an operator.
  kNotSpecified = 0x1000,
};

// This class not only builds the op proto but also initializes the attribute
// checkers.
class OpProtoAndCheckerMaker {
 public:
  static const char *OpRoleAttrName() { return "op_role"; }
  static const char *OpRoleVarAttrName() { return "op_role_var"; }

  void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);

  virtual void Make() = 0;

  virtual ~OpProtoAndCheckerMaker() {
    CHECK(validated_) << "should call Validate after build";
  }

 protected:
  struct VariableBuilder {
    proto::OpProto::Var *var_;
@@ -76,6 +86,7 @@ class OpProtoAndCheckerMaker {

 private:
  void CheckNoDuplicatedInOutAttrs();
  void Validate();

  proto::OpProto *proto_;
  OpAttrChecker *op_checker_;
......
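The OpRole values are bit flags, so one op can carry several roles at once; the generated scale-of-loss op, for instance, is both Backward and Loss, which is exactly what the builder's IsScaleLossOp check above tests for. A small illustrative sketch of how the flags combine, with the constants copied from the enum above:

# Sketch: roles combine with bitwise OR and are tested with bitwise AND.
# These constants mirror the C++ OpRole enum above.
FORWARD, BACKWARD, OPTIMIZE, LOSS, NOT_SPECIFIED = 0x0000, 0x0001, 0x0002, 0x0100, 0x1000

loss_grad_role = BACKWARD | LOSS        # role of the generated scale-of-loss op
assert loss_grad_role & BACKWARD        # part of the backward pass
assert loss_grad_role & LOSS            # and attached to the loss
assert not (loss_grad_role & OPTIMIZE)  # but not an optimize op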
@@ -28,10 +28,8 @@ TEST(ProtoMaker, DuplicatedAttr) {
  paddle::framework::proto::OpProto op_proto;
  paddle::framework::OpAttrChecker op_checker;
  TestAttrProtoMaker proto_maker;
  ASSERT_THROW(proto_maker(&op_proto, &op_checker),
               paddle::platform::EnforceNotMet);
}

class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
@@ -46,8 +44,6 @@ TEST(ProtoMaker, DuplicatedInOut) {
  paddle::framework::proto::OpProto op_proto;
  paddle::framework::OpAttrChecker op_checker;
  TestAttrProtoMaker proto_maker;
  ASSERT_THROW(proto_maker(&op_proto, &op_checker),
               paddle::platform::EnforceNotMet);
}
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/pybind/const_value.h"
#include <paddle/fluid/framework/op_proto_maker.h>
#include "paddle/fluid/framework/operator.h"

namespace paddle {
@@ -23,6 +24,21 @@ void BindConstValue(pybind11::module* m) {
  m->def("kTempVarName", [] { return framework::kTempVarName; });
  m->def("kGradVarSuffix", [] { return framework::kGradVarSuffix; });
  m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; });

  auto op_proto_and_checker_maker =
      m->def_submodule("op_proto_and_checker_maker");

  pybind11::enum_<framework::OpRole>(op_proto_and_checker_maker, "OpRole")
      .value("Forward", framework::OpRole::kForward)
      .value("Backward", framework::OpRole::kBackward)
      .value("Optimize", framework::OpRole::kOptimize)
      .value("Loss", framework::OpRole::kLoss);

  op_proto_and_checker_maker.def(
      "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
  op_proto_and_checker_maker.def(
      "kOpRoleVarAttrName",
      framework::OpProtoAndCheckerMaker::OpRoleVarAttrName);
}

}  // namespace pybind
......
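These bindings are what the Python-side changes below build on: the attribute names and role values become queryable from fluid.core. A hedged usage sketch, assuming a PaddlePaddle build that already contains this patch:

# Sketch of how the new pybind submodule is consumed from Python; it mirrors
# the calls made in backward.py and framework.py below.
import paddle.fluid.core as core

op_maker = core.op_proto_and_checker_maker
print(op_maker.kOpRoleAttrName())     # expected: "op_role"
print(op_maker.kOpRoleVarAttrName())  # expected: "op_role_var"
print(int(op_maker.OpRole.Backward) | int(op_maker.OpRole.Loss))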
@@ -51,6 +51,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
        op_desc.set_input(para, args)
    for para, args in outputs.iteritems():
        op_desc.set_output(para, args)

    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()

    if op_role_attr_name not in attrs:
        attrs[
            op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward

    for name, val in attrs.iteritems():
        if isinstance(val, framework.Block):
            op_desc.set_block_attr(name, val.desc)
@@ -141,7 +147,7 @@ def _addup_repetitive_outputs_(op_descs):
            else:
                if len(renamed_vars[var_name]) == 1:
                    new_name = var_name + "@RENAME@" + \
                        str(var_rename_count[var_name])
                    var_rename_count[var_name] += 1
                    # rename original var_name
                    renamed_vars[var_name][0] = new_name
@@ -149,7 +155,7 @@ def _addup_repetitive_outputs_(op_descs):
                    _rename_arg_(pending_sum_ops, var_name, new_name)

                new_name = var_name + "@RENAME@" + \
                    str(var_rename_count[var_name])
                var_rename_count[var_name] += 1
                op_desc.rename_output(var_name, new_name)
                renamed_vars[var_name].append(new_name)
@@ -335,9 +341,12 @@ def _append_backward_ops_(block,
                                     no_grad_dict[block.idx])

    # append op_desc in grad_op_descs to target_block
    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
    backward = core.op_proto_and_checker_maker.OpRole.Backward
    for op_desc in grad_op_descs:
        new_op_desc = target_block.desc.append_op()
        new_op_desc.copy_from(op_desc)
        new_op_desc.set_attr(op_role_attr_name, backward)
        grad_to_var["__current_op_desc__"] = new_op_desc
        if callbacks is not None:
            assert (isinstance(callbacks, list))
@@ -439,6 +448,22 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
        (list[(Variable,Variable)]): list of (parameter, gradient) pair.
    """
    assert isinstance(loss, framework.Variable)

    if loss.op is None:
        # the loss is from a cloned program. Find the loss op manually.
        for op in reversed(loss.block.ops):
            assert isinstance(op, framework.Operator)
            if len(op.output_arg_names) == 1 and op.output_arg_names[
                    0] == loss.name:
                loss.op = op
                break
        if loss.op is None:
            raise ValueError("loss.op is None. Should not happen")

    loss.op.set_attr(core.op_proto_and_checker_maker.kOpRoleAttrName(),
                     int(core.op_proto_and_checker_maker.OpRole.Forward) |
                     int(core.op_proto_and_checker_maker.OpRole.Loss))

    if callbacks is not None:
        isinstance(callbacks, list)
@@ -456,12 +481,16 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
    current_block_idx = program.current_block_idx
    grad_to_var = dict()

    op_desc = _create_op_desc_(
        "fill_constant", {}, {"Out": [_append_grad_suffix_(loss.name)]}, {
            "shape": [1],
            "value": 1.0,
            "dtype": loss.dtype,
            "force_cpu": False,
            core.op_proto_and_checker_maker.kOpRoleAttrName():
            int(core.op_proto_and_checker_maker.OpRole.Backward) |
            int(core.op_proto_and_checker_maker.OpRole.Loss),
        })
    root_block.desc.append_op().copy_from(op_desc)

    block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))
@@ -505,6 +534,24 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))

    op_role_var_attr_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
    for p, g in params_and_grads:
        if g is None:
            continue
        for op in reversed(program.global_block().ops):
            assert isinstance(op, framework.Operator)
            if g.name in op.output_arg_names:
                g.op = op
                break

        if g.op is None:
            raise ValueError("Unexpected branch")
        attr_val = [p.name, g.name]
        if g.op.has_attr(op_role_var_attr_name):
            attr_val.extend(g.op.attr(op_role_var_attr_name))
        g.op.set_attr(op_role_var_attr_name, attr_val)

    return params_and_grads
......
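With these changes, append_backward marks the loss op as Forward|Loss, the filled loss gradient as Backward|Loss, every generated gradient op as Backward, and attaches (parameter, gradient) name pairs through op_role_var. A hedged inspection sketch; the tiny network below is only illustrative, any trainable layer plus a mean loss would do:

# Sketch: build a trivial network, run append_backward, then print the role
# attributes it attached to each op in the program.
import paddle.fluid as fluid
import paddle.fluid.core as core

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
loss = fluid.layers.mean(fluid.layers.fc(input=x, size=1))
fluid.backward.append_backward(loss)

role_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
role_var_name = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
for op in fluid.default_main_program().global_block().ops:
    role_vars = op.attr(role_var_name) if op.has_attr(role_var_name) else []
    print(op.type, op.attr(role_name), role_vars)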
@@ -214,21 +214,24 @@ def set_gradient_clip(clip, param_list=None, program=None):

def append_gradient_clip_ops(param_grad):
    context = dict()
    for p, g in param_grad:
        with p.block.program.optimized_guard(p):
            clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr())
            if clip_attr is None:
                clip_attr = NullGradientClipAttr()
            if not isinstance(clip_attr, BaseGradientClipAttr):
                raise TypeError(
                    "clip attribute should be an instance of BaseGradientClipAttr"
                )

            clip_attr.process_context(context=context, param=p, grad=g)

    res = []
    for p, g in param_grad:
        with p.block.program.optimized_guard(p):
            res.append(clip_attr.create_operators(param=p, grad=g))

    return res


ClipByValue = GradientClipByValue
......
@@ -404,6 +404,23 @@ class Operator(object):
        self.block = block
        self.desc = desc
        self.attrs = attrs
        if self.attrs is None:
            self.attrs = dict()
        del attrs

        op_maker = core.op_proto_and_checker_maker

        if op_maker.kOpRoleAttrName() not in self.attrs:
            self.attrs[op_maker.kOpRoleAttrName()] = self.block.program.op_role

        role_var_name = op_maker.kOpRoleVarAttrName()
        if len(self.block.program.
               op_role_var) != 0 and role_var_name not in self.attrs:
            self.attrs[role_var_name] = self.block.program.op_role_var

        if role_var_name in self.attrs and len(self.attrs[role_var_name]) == 0:
            del self.attrs[role_var_name]

        if len(self.desc.type()) != 0:
            return
        if type is None:
@@ -469,22 +486,23 @@ class Operator(object):
                    arg.op = self
                self.desc.set_output(out_proto.name, out_arg_names)

        if self.attrs is not None:
            if not isinstance(self.attrs, dict):
                raise TypeError("'attrs' should be a dict.")
            for attr in proto.attrs:
                attr_name = attr.name
                if (attr_name not in self.attrs) or (
                        self.attrs[attr_name] is None):
                    continue
                if isinstance(self.attrs[attr_name], Block):
                    self.desc.set_block_attr(attr_name,
                                             self.attrs[attr_name].desc)
                elif isinstance(self.attrs[attr_name], core.BlockDesc) or \
                        isinstance(self.attrs[attr_name], core.ProgramDesc):
                    self.desc.set_serialized_attr(
                        attr_name, self.attrs[attr_name].serialize_to_string())
                else:
                    self.desc.set_attr(attr_name, self.attrs[attr_name])

        self.desc.check_attrs()
        no_kernel_op_set = {
            'feed', 'fetch', 'save', 'load', 'recurrent', 'go',
@@ -612,6 +630,10 @@ class Operator(object):
        """
        return self.desc.attr_type(name)

    def set_attr(self, name, val):
        self.attrs[name] = val
        self.desc.set_attr(name, val)

    @property
    def attr_names(self):
        """
@@ -1002,6 +1024,33 @@ class Program(object):
        self.blocks = [Block(self, 0)]
        self.current_block_idx = 0
        self._seed = 0
        self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
        self._op_role_var = []

    @property
    def op_role(self):
        return self._current_role

    @op_role.setter
    def set_op_role(self, role):
        self._current_role = role

    @property
    def op_role_var(self):
        return self._op_role_var

    @op_role_var.setter
    def set_op_role_var(self, var_name):
        self._op_role_var = [var_name]

    @contextlib.contextmanager
    def optimized_guard(self, var):
        OpRole = core.op_proto_and_checker_maker.OpRole
        self._current_role = OpRole.Optimize
        self._op_role_var = [var.name if isinstance(var, Variable) else var]
        yield
        self._op_role_var = []
        self._current_role = OpRole.Forward

    def __str__(self):
        return self.to_string(True)
......
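optimized_guard is the piece that ties the Python side together: every Operator created while the guard is active picks up op_role = Optimize and op_role_var = [parameter name] from the Program, which is how the gradient-clip change above and the optimizer and regularizer changes below tag their ops. A minimal usage sketch; param and grad are assumed to be existing Variables of the default program, and the scale op is only an example of "some update op":

# Sketch of the optimized_guard pattern; `param` and `grad` are assumed to
# already exist in the program, so this is illustrative rather than runnable
# as-is.
import paddle.fluid as fluid

program = fluid.default_main_program()
with program.optimized_guard(param):
    # ops appended here get op_role = Optimize and op_role_var = [param.name]
    program.global_block().append_op(
        type='scale',
        inputs={'X': grad},
        outputs={'Out': grad},
        attrs={'scale': 0.1})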
@@ -213,11 +213,13 @@ class Optimizer(object):
        optimize_ops = []
        for param_and_grad in parameters_and_grads:
            with param_and_grad[0].block.program.optimized_guard(
                    param_and_grad[0]):
                if param_and_grad[0].trainable is True and param_and_grad[
                        1] is not None:
                    optimize_op = self._append_optimize_op(loss.block,
                                                           param_and_grad)
                    optimize_ops.append(optimize_op)

        # Get custom finish ops for subclasses
        # FIXME: Need to fix this once we figure out how to handle dependencies
......
@@ -43,31 +43,32 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
    """
    params_and_grads = []
    for param, grad in parameters_and_grads:
        with param.block.program.optimized_guard(param):
            # If no gradient then we don't need to do anything
            if grad is None:
                params_and_grads.append((param, grad))
                continue

            regularization_term = None
            if param.regularizer is not None:
                # Add variable for regularization term in grad block
                regularization_term = param.regularizer(param, grad, grad.block)
            elif regularization is not None:
                regularization_term = regularization(param, grad, grad.block)

            # If no regularization specified, then we don't need to do anything
            if regularization_term is None:
                params_and_grads.append((param, grad))
                continue

            assert grad.shape == regularization_term.shape

            grad.block.append_op(
                type='elementwise_add',
                inputs={"X": grad,
                        "Y": regularization_term},
                outputs={"Out": grad})
            params_and_grads.append((param, grad))

    return params_and_grads
......
@@ -36,6 +36,12 @@ def randomize_probability(batch_size, class_num, dtype='float32'):

def create_op(scope, op_type, inputs, outputs, attrs):
    kwargs = dict()

    op_maker = core.op_proto_and_checker_maker
    op_role_attr_name = op_maker.kOpRoleAttrName()

    if op_role_attr_name not in attrs:
        attrs[op_role_attr_name] = int(op_maker.OpRole.Forward)

    def __create_var__(name, var_name):
        scope.var(var_name).get_tensor()
        kwargs[name].append(var_name)
......
@@ -63,7 +63,10 @@ class TestOperator(unittest.TestCase):
        self.assertEqual(mul_op.output("Out"), ["mul.out"])
        self.assertEqual(
            set(mul_op.attr_names),
            set([
                "x_num_col_dims", "y_num_col_dims", "use_mkldnn", "op_role",
                "op_role_var"
            ]))
        self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
        self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
        self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
......