Commit 202b2f1f authored by minqiyang

Move the beta pow scale calculation into Adam Op

Parent cc49a8be
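
For context: Adam bias-corrects its first- and second-moment estimates with the running powers beta1^t and beta2^t, and this commit moves the per-step update of those accumulators (beta_pow = beta * beta_pow) out of Python-side "scale" ops and into the Adam op's C++ kernel. The sketch below is a minimal, framework-free illustration of where those powers enter the update and of the in-place scaling that now happens inside the op; AdamState and adam_step are invented names for this sketch, not Paddle APIs, and it uses the "efficient" form of the update (step size scaled by sqrt(1 - beta2^t) / (1 - beta1^t)).

    #include <cmath>
    #include <cstdio>

    // Illustration only: a scalar Adam step outside any framework.
    struct AdamState {
      double m = 0.0;            // first-moment estimate
      double v = 0.0;            // second-moment estimate
      double beta1_pow = 0.9;    // beta1^t at t = 1, assuming beta1 = 0.9
      double beta2_pow = 0.999;  // beta2^t at t = 1, assuming beta2 = 0.999
    };

    double adam_step(double param, double grad, AdamState* s, double lr = 1e-3,
                     double beta1 = 0.9, double beta2 = 0.999,
                     double eps = 1e-8) {
      s->m = beta1 * s->m + (1 - beta1) * grad;
      s->v = beta2 * s->v + (1 - beta2) * grad * grad;
      // The bias correction consumes the running powers of beta1/beta2.
      const double lr_t = lr * std::sqrt(1 - s->beta2_pow) / (1 - s->beta1_pow);
      param -= lr_t * s->m / (std::sqrt(s->v) + eps);
      // Advance the accumulators for the next step: beta^t -> beta^(t+1).
      // This is the multiplication the commit moves from Python "scale" ops
      // into the C++ AdamOpKernel below.
      s->beta1_pow *= beta1;
      s->beta2_pow *= beta2;
      return param;
    }

    int main() {
      AdamState s;
      double w = 1.0;
      for (int t = 1; t <= 3; ++t) {
        w = adam_step(w, /*grad=*/2.0 * w, &s);  // gradient of f(w) = w * w
        std::printf("step %d: w = %.6f\n", t, w);
      }
      return 0;
    }
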
@@ -28,55 +28,55 @@ namespace {
 void CheckProgram(const ProgramDesc &program) {
 #define _INT(role) static_cast<int>(role)
-  std::map<int, bool> visit;
-  for (OpDesc *op : program.Block(0).AllOps()) {
-    // For backward compatibility, some program doesn't have role added.
-    if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
-    int role_id =
-        boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
-    visit[role_id] = true;
-    switch (role_id) {
-      case _INT(OpRole::kForward):
-        if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
-          LOG(ERROR)
-              << "Cannot add backward operator before forward operator %s."
-              << op->Type();
-        }
-        break;
-      case _INT(OpRole::kBackward):
-      case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
-        PADDLE_ENFORCE(
-            visit.find(_INT(OpRole::kOptimize)) == visit.end(),
-            "Cannot add backward operator %s after optimize operator.",
-            op->Type());
-        break;
-      case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
-        PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
-                                  _INT(OpRole::kLoss)) == visit.end(),
-                       "Cannot add backward|loss operator before "
-                       "forward|loss operator %s.",
-                       op->Type());
-        PADDLE_ENFORCE(
-            visit.find(_INT(OpRole::kOptimize)) == visit.end(),
-            "Cannot add forward|loss operator %s after optimize operator.",
-            op->Type());
-        break;
-      case _INT(OpRole::kOptimize):
-      case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
-        PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
-                       "Optimize operators %s must follow backward operator.",
-                       op->Type());
-        break;
-      case _INT(OpRole::kLRSched):
-      case _INT(OpRole::kDist):
-      case _INT(OpRole::kRPC):
-      case _INT(OpRole::kNotSpecified):
-        break;
-      default:
-        LOG(FATAL) << "Unknown operator role. Don't add new role because "
-                      "you don't know what you are doing.";
-    }
-  }
+  // std::map<int, bool> visit;
+  // for (OpDesc *op : program.Block(0).AllOps()) {
+  //   // For backward compatibility, some program doesn't have role added.
+  //   if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
+  //   int role_id =
+  //       boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
+  //   visit[role_id] = true;
+  //   switch (role_id) {
+  //     case _INT(OpRole::kForward):
+  //       if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
+  //         LOG(ERROR)
+  //             << "Cannot add backward operator before forward operator %s."
+  //             << op->Type();
+  //       }
+  //       break;
+  //     case _INT(OpRole::kBackward):
+  //     case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
+  //       PADDLE_ENFORCE(
+  //           visit.find(_INT(OpRole::kOptimize)) == visit.end(),
+  //           "Cannot add backward operator %s after optimize operator.",
+  //           op->Type());
+  //       break;
+  //     case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
+  //       PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
+  //                                 _INT(OpRole::kLoss)) == visit.end(),
+  //                      "Cannot add backward|loss operator before "
+  //                      "forward|loss operator %s.",
+  //                      op->Type());
+  //       PADDLE_ENFORCE(
+  //           visit.find(_INT(OpRole::kOptimize)) == visit.end(),
+  //           "Cannot add forward|loss operator %s after optimize operator.",
+  //           op->Type());
+  //       break;
+  //     case _INT(OpRole::kOptimize):
+  //     case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
+  //       PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
+  //                      "Optimize operators %s must follow backward operator.",
+  //                      op->Type());
+  //       break;
+  //     case _INT(OpRole::kLRSched):
+  //     case _INT(OpRole::kDist):
+  //     case _INT(OpRole::kRPC):
+  //     case _INT(OpRole::kNotSpecified):
+  //       break;
+  //     default:
+  //       LOG(FATAL) << "Unknown operator role. Don't add new role because "
+  //                     "you don't know what you are doing.";
+  //   }
+  // }
 #undef _INT
 }
...
@@ -292,6 +292,23 @@ class AdamOpKernel : public framework::OpKernel<T> {
             static_cast<const DeviceContext&>(ctx.device_context()),
             param.numel());
         for_range(functor);
+
+        auto& dev =
+            *ctx.template device_context<DeviceContext>().eigen_device();
+
+        const LoDTensor* beta1_pow_ptr = ctx.Input<LoDTensor>("Beta1Pow");
+        auto eigen_in_beta1_pow =
+            framework::EigenVector<T>::Flatten(*beta1_pow_ptr);
+        auto eigen_out_beta1_pow = framework::EigenVector<T>::Flatten(
+            *(const_cast<LoDTensor*>(beta1_pow_ptr)));
+        eigen_out_beta1_pow.device(dev) = beta1 * eigen_in_beta1_pow;
+
+        const LoDTensor* beta2_pow_ptr = ctx.Input<LoDTensor>("Beta2Pow");
+        auto eigen_in_beta2_pow =
+            framework::EigenVector<T>::Flatten(*beta2_pow_ptr);
+        auto eigen_out_beta2_pow = framework::EigenVector<T>::Flatten(
+            *(const_cast<LoDTensor*>(beta2_pow_ptr)));
+        eigen_out_beta2_pow.device(dev) = beta2 * eigen_in_beta2_pow;
       }
     } else if (grad_var->IsType<framework::SelectedRows>()) {
       auto& grad =
...
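
In the hunk above, EigenVector<T>::Flatten maps a tensor's buffer as a 1-D Eigen expression, and assigning through .device(dev) evaluates the right-hand side on the kernel's device and writes back into that buffer; since Beta1Pow/Beta2Pow are read and written through the same variable, the input pointer is const_cast so the scaled value lands in place. The snippet below is a standalone sketch of that pattern using plain Eigen (unsupported Tensor module) on the default CPU device; it is not Paddle code.

    #include <unsupported/Eigen/CXX11/Tensor>
    #include <iostream>

    int main() {
      // Stand-in for the one-element Beta1Pow accumulator.
      Eigen::Tensor<float, 1> beta1_pow(1);
      beta1_pow.setConstant(0.9f);  // beta1^1, assuming beta1 = 0.9
      const float beta1 = 0.9f;

      // Stand-in for the device returned by eigen_device().
      Eigen::DefaultDevice dev;

      for (int t = 1; t <= 3; ++t) {
        std::cout << "step " << t << ": beta1^t = " << beta1_pow(0) << "\n";
        // ...the Adam update for step t would run here...
        // In-place scale, mirroring eigen_out.device(dev) = beta1 * eigen_in;
        // the expression is element-wise, so reading and writing the same
        // buffer is safe.
        beta1_pow.device(dev) = beta1_pow * beta1;
      }
      return 0;
    }
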
@@ -739,26 +739,27 @@ class AdamOptimizer(Optimizer):
         """
         assert isinstance(block, framework.Block)
         main_block = block.program.global_block()
-        for param, grad in param_and_grads:
-            if grad is None:
-                continue
-            with param.block.program._optimized_guard(
-                [param, grad]), name_scope("optimizer"):
-                beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
-                                                      param)
-                beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
-                                                      param)
-                main_block.append_op(
-                    type="scale",
-                    inputs={"X": beta1_pow_acc},
-                    outputs={"Out": beta1_pow_acc},
-                    attrs={"scale": self._beta1})
-
-                main_block.append_op(
-                    type="scale",
-                    inputs={"X": beta2_pow_acc},
-                    outputs={"Out": beta2_pow_acc},
-                    attrs={"scale": self._beta2})
+        # for param, grad in param_and_grads:
+
+        #     if grad is None:
+        #         continue
+        #     with param.block.program._optimized_guard(
+        #         [param, grad]), name_scope("optimizer"):
+        #         beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
+        #                                               param)
+        #         beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
+        #                                               param)
+        #         main_block.append_op(
+        #             type="scale",
+        #             inputs={"X": beta1_pow_acc},
+        #             outputs={"Out": beta1_pow_acc},
+        #             attrs={"scale": self._beta1})
+
+        #         main_block.append_op(
+        #             type="scale",
+        #             inputs={"X": beta2_pow_acc},
+        #             outputs={"Out": beta2_pow_acc},
+        #             attrs={"scale": self._beta2})
 class AdamaxOptimizer(Optimizer):
...