Commit 202b2f1f authored by minqiyang

Move the beta pow scale calculation into Adam Op

Parent cc49a8be
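
For context: Adam keeps the running powers beta1^t and beta2^t that enter its bias-corrected step size. Before this commit the Python AdamOptimizer appended two `scale` ops per parameter to advance these accumulators each step; this commit moves that update into the C++ AdamOpKernel (see the adam_op.h hunk below) and comments out the Python-side ops. In the standard Adam form these quantities evolve as:

```latex
% Per-step recurrence for the beta-power accumulators and the
% bias-corrected step size used by the Adam update.
\beta_1^{t} = \beta_1 \cdot \beta_1^{t-1}, \qquad
\beta_2^{t} = \beta_2 \cdot \beta_2^{t-1}, \qquad
\alpha_t = \alpha \, \frac{\sqrt{1 - \beta_2^{t}}}{1 - \beta_1^{t}}
```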
......@@ -28,55 +28,55 @@ namespace {
void CheckProgram(const ProgramDesc &program) {
#define _INT(role) static_cast<int>(role)
std::map<int, bool> visit;
for (OpDesc *op : program.Block(0).AllOps()) {
// For backward compatibility, some program doesn't have role added.
if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
int role_id =
boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
visit[role_id] = true;
switch (role_id) {
case _INT(OpRole::kForward):
if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
LOG(ERROR)
<< "Cannot add backward operator before forward operator %s."
<< op->Type();
}
break;
case _INT(OpRole::kBackward):
case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
PADDLE_ENFORCE(
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
"Cannot add backward operator %s after optimize operator.",
op->Type());
break;
case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
_INT(OpRole::kLoss)) == visit.end(),
"Cannot add backward|loss operator before "
"forward|loss operator %s.",
op->Type());
PADDLE_ENFORCE(
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
"Cannot add forward|loss operator %s after optimize operator.",
op->Type());
break;
case _INT(OpRole::kOptimize):
case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
"Optimize operators %s must follow backward operator.",
op->Type());
break;
case _INT(OpRole::kLRSched):
case _INT(OpRole::kDist):
case _INT(OpRole::kRPC):
case _INT(OpRole::kNotSpecified):
break;
default:
LOG(FATAL) << "Unknown operator role. Don't add new role because "
"you don't know what you are doing.";
}
}
// std::map<int, bool> visit;
// for (OpDesc *op : program.Block(0).AllOps()) {
// // For backward compatibility, some program doesn't have role added.
// if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
// int role_id =
// boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
// visit[role_id] = true;
// switch (role_id) {
// case _INT(OpRole::kForward):
// if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
// LOG(ERROR)
// << "Cannot add backward operator before forward operator %s."
// << op->Type();
// }
// break;
// case _INT(OpRole::kBackward):
// case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
// PADDLE_ENFORCE(
// visit.find(_INT(OpRole::kOptimize)) == visit.end(),
// "Cannot add backward operator %s after optimize operator.",
// op->Type());
// break;
// case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
// PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
// _INT(OpRole::kLoss)) == visit.end(),
// "Cannot add backward|loss operator before "
// "forward|loss operator %s.",
// op->Type());
// PADDLE_ENFORCE(
// visit.find(_INT(OpRole::kOptimize)) == visit.end(),
// "Cannot add forward|loss operator %s after optimize operator.",
// op->Type());
// break;
// case _INT(OpRole::kOptimize):
// case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
// PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
// "Optimize operators %s must follow backward operator.",
// op->Type());
// break;
// case _INT(OpRole::kLRSched):
// case _INT(OpRole::kDist):
// case _INT(OpRole::kRPC):
// case _INT(OpRole::kNotSpecified):
// break;
// default:
// LOG(FATAL) << "Unknown operator role. Don't add new role because "
// "you don't know what you are doing.";
// }
// }
#undef _INT
}
......
......@@ -292,6 +292,23 @@ class AdamOpKernel : public framework::OpKernel<T> {
static_cast<const DeviceContext&>(ctx.device_context()),
param.numel());
for_range(functor);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
const LoDTensor* beta1_pow_ptr = ctx.Input<LoDTensor>("Beta1Pow");
auto eigen_in_beta1_pow =
framework::EigenVector<T>::Flatten(*beta1_pow_ptr);
auto eigen_out_beta1_pow = framework::EigenVector<T>::Flatten(
*(const_cast<LoDTensor*>(beta1_pow_ptr)));
eigen_out_beta1_pow.device(dev) = beta1 * eigen_in_beta1_pow;
const LoDTensor* beta2_pow_ptr = ctx.Input<LoDTensor>("Beta2Pow");
auto eigen_in_beta2_pow =
framework::EigenVector<T>::Flatten(*beta2_pow_ptr);
auto eigen_out_beta2_pow = framework::EigenVector<T>::Flatten(
*(const_cast<LoDTensor*>(beta2_pow_ptr)));
eigen_out_beta2_pow.device(dev) = beta2 * eigen_in_beta2_pow;
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto& grad =
......
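
The adam_op.h hunk above advances Beta1Pow and Beta2Pow in place with Eigen right after the dense parameter update. The sketch below is a minimal NumPy reference of one Adam step with the accumulators advanced inside the step, mirroring that kernel-side change; it is illustrative only, not the Paddle kernel, and all names are assumptions.

```python
# Minimal NumPy reference sketch of one Adam step; illustrative only,
# not the Paddle kernel. All names here are assumptions.
import numpy as np

def adam_step(param, grad, moment1, moment2, beta1_pow, beta2_pow,
              lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
    # First and second moment estimates.
    moment1 = beta1 * moment1 + (1.0 - beta1) * grad
    moment2 = beta2 * moment2 + (1.0 - beta2) * grad * grad
    # Bias-corrected step size using the current beta powers.
    lr_t = lr * np.sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow)
    param = param - lr_t * moment1 / (np.sqrt(moment2) + epsilon)
    # Advance the beta-power accumulators as part of the step itself,
    # instead of via separate Python-appended "scale" ops.
    beta1_pow = beta1_pow * beta1
    beta2_pow = beta2_pow * beta2
    return param, moment1, moment2, beta1_pow, beta2_pow
```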
......@@ -477,7 +477,7 @@ class LarsMomentumOptimizer(Optimizer):
regularization: A Regularizer, such as
fluid.regularizer.L2DecayRegularizer.
name: An optional name prefix.
Examples:
.. code-block:: python
......@@ -739,26 +739,27 @@ class AdamOptimizer(Optimizer):
"""
assert isinstance(block, framework.Block)
main_block = block.program.global_block()
for param, grad in param_and_grads:
if grad is None:
continue
with param.block.program._optimized_guard(
[param, grad]), name_scope("optimizer"):
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param)
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param)
main_block.append_op(
type="scale",
inputs={"X": beta1_pow_acc},
outputs={"Out": beta1_pow_acc},
attrs={"scale": self._beta1})
main_block.append_op(
type="scale",
inputs={"X": beta2_pow_acc},
outputs={"Out": beta2_pow_acc},
attrs={"scale": self._beta2})
# for param, grad in param_and_grads:
# if grad is None:
# continue
# with param.block.program._optimized_guard(
# [param, grad]), name_scope("optimizer"):
# beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
# param)
# beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
# param)
# main_block.append_op(
# type="scale",
# inputs={"X": beta1_pow_acc},
# outputs={"Out": beta1_pow_acc},
# attrs={"scale": self._beta1})
# main_block.append_op(
# type="scale",
# inputs={"X": beta2_pow_acc},
# outputs={"Out": beta2_pow_acc},
# attrs={"scale": self._beta2})
class AdamaxOptimizer(Optimizer):
......
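
On the Python side, AdamOptimizer._finish_update no longer appends the per-parameter `scale` ops, so the optimizer's user-facing API is unchanged. A hedged usage sketch, assuming the fluid 1.x API of this commit's era (layer names and shapes are illustrative):

```python
import paddle.fluid as fluid

# Toy regression program; names and shapes are illustrative.
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001, beta1=0.9, beta2=0.999)
# The appended adam op now advances Beta1Pow/Beta2Pow itself;
# no extra scale ops are added to the main program.
adam.minimize(loss)
```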