提交 2ecc5622 编写于 作者: X Xin Pan 提交者: whs

small AverageOptimizer enhance. (#11761)

* small AverageOptimizer enhance.

* clean

* clean
上级 19e877ff
......@@ -19,28 +19,28 @@ namespace operators {
template <>
void GetAccumulators<paddle::platform::CPUDeviceContext>(
const framework::ExecutionContext& ctx, int64_t* num_updates_,
int64_t* num_accumulates_, int64_t* old_num_accumulates_) {
const framework::ExecutionContext& ctx, int64_t* num_updates,
int64_t* num_accumulates, int64_t* old_num_accumulates) {
auto* in_old_num_accumulates = ctx.Input<Tensor>("in_old_num_accumulates");
auto* in_num_accumulates = ctx.Input<Tensor>("in_num_accumulates");
auto* in_num_updates = ctx.Input<Tensor>("in_num_updates");
*old_num_accumulates_ = in_old_num_accumulates->data<int64_t>()[0];
*num_accumulates_ = in_num_accumulates->data<int64_t>()[0];
*num_updates_ = in_num_updates->data<int64_t>()[0];
*old_num_accumulates = in_old_num_accumulates->data<int64_t>()[0];
*num_accumulates = in_num_accumulates->data<int64_t>()[0];
*num_updates = in_num_updates->data<int64_t>()[0];
}
template <>
void SetAccumulators<paddle::platform::CPUDeviceContext>(
const framework::ExecutionContext& ctx, int64_t num_updates_,
int64_t num_accumulates_, int64_t old_num_accumulates_) {
const framework::ExecutionContext& ctx, int64_t num_updates,
int64_t num_accumulates, int64_t old_num_accumulates) {
auto* out_old_num_accumulates = ctx.Output<Tensor>("out_old_num_accumulates");
auto* out_num_accumulates = ctx.Output<Tensor>("out_num_accumulates");
auto* out_num_updates = ctx.Output<Tensor>("out_num_updates");
out_old_num_accumulates->data<int64_t>()[0] = old_num_accumulates_;
out_num_accumulates->data<int64_t>()[0] = num_accumulates_;
out_num_updates->data<int64_t>()[0] = num_updates_;
out_old_num_accumulates->data<int64_t>()[0] = old_num_accumulates;
out_num_accumulates->data<int64_t>()[0] = num_accumulates;
out_num_updates->data<int64_t>()[0] = num_updates;
}
class AverageAccumulatesOp : public framework::OperatorWithKernel {
......@@ -177,7 +177,7 @@ class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
AverageAccumulates Operator.
Accumulate the sum of parameter whtin sliding window. The size of sliding window is
Accumulate the sum of parameter within sliding window. The size of sliding window is
determined by 'average_window', 'max_average_window' and 'min_average_window'.
Memory was shared by Input(in_sum_1) and Output(out_sum_1) which acts as an accumulator 'sum_1'.
'sum_2', 'sum_3', 'num_accumulates', 'old_num_accumulates' and 'num_updates' were the same as 'sum_1'.
......
......@@ -54,8 +54,9 @@ class AverageAccumulatesKernel : public framework::OpKernel<T> {
float average_window = ctx.Attr<float>("average_window");
int64_t max_average_window = ctx.Attr<int64_t>("max_average_window");
int64_t min_average_window = ctx.Attr<int64_t>("min_average_window");
min_average_window =
std::min<int64_t>(min_average_window, max_average_window);
PADDLE_ENFORCE_LE(min_average_window, max_average_window,
"min_average_window shouldn't be larger than "
"max_average_window");
// Get inputs
auto* param = ctx.Input<Tensor>("param");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册