Commit 2ecc5622 authored by Xin Pan, committed by whs

small AverageOptimizer enhance. (#11761)

* small AverageOptimizer enhance.

* clean

* clean
Parent 19e877ff
@@ -19,28 +19,28 @@ namespace operators {
 
 template <>
 void GetAccumulators<paddle::platform::CPUDeviceContext>(
-    const framework::ExecutionContext& ctx, int64_t* num_updates_,
-    int64_t* num_accumulates_, int64_t* old_num_accumulates_) {
+    const framework::ExecutionContext& ctx, int64_t* num_updates,
+    int64_t* num_accumulates, int64_t* old_num_accumulates) {
   auto* in_old_num_accumulates = ctx.Input<Tensor>("in_old_num_accumulates");
   auto* in_num_accumulates = ctx.Input<Tensor>("in_num_accumulates");
   auto* in_num_updates = ctx.Input<Tensor>("in_num_updates");
-  *old_num_accumulates_ = in_old_num_accumulates->data<int64_t>()[0];
-  *num_accumulates_ = in_num_accumulates->data<int64_t>()[0];
-  *num_updates_ = in_num_updates->data<int64_t>()[0];
+  *old_num_accumulates = in_old_num_accumulates->data<int64_t>()[0];
+  *num_accumulates = in_num_accumulates->data<int64_t>()[0];
+  *num_updates = in_num_updates->data<int64_t>()[0];
 }
 
 template <>
 void SetAccumulators<paddle::platform::CPUDeviceContext>(
-    const framework::ExecutionContext& ctx, int64_t num_updates_,
-    int64_t num_accumulates_, int64_t old_num_accumulates_) {
+    const framework::ExecutionContext& ctx, int64_t num_updates,
+    int64_t num_accumulates, int64_t old_num_accumulates) {
   auto* out_old_num_accumulates = ctx.Output<Tensor>("out_old_num_accumulates");
   auto* out_num_accumulates = ctx.Output<Tensor>("out_num_accumulates");
   auto* out_num_updates = ctx.Output<Tensor>("out_num_updates");
-  out_old_num_accumulates->data<int64_t>()[0] = old_num_accumulates_;
-  out_num_accumulates->data<int64_t>()[0] = num_accumulates_;
-  out_num_updates->data<int64_t>()[0] = num_updates_;
+  out_old_num_accumulates->data<int64_t>()[0] = old_num_accumulates;
+  out_num_accumulates->data<int64_t>()[0] = num_accumulates;
+  out_num_updates->data<int64_t>()[0] = num_updates;
 }
 
 class AverageAccumulatesOp : public framework::OperatorWithKernel {
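
The rename in this hunk is cosmetic but worthwhile: in the Google C++ style that PaddlePaddle follows, a trailing underscore marks a class data member, so parameters named num_updates_ read as members rather than locals. A minimal illustration of the convention (the class name here is hypothetical):

    #include <cstdint>

    // Trailing underscores belong on data members, not on parameters.
    class WindowCounter {
     public:
      void Set(int64_t num_updates) { num_updates_ = num_updates; }  // parameter: no underscore
      int64_t Get() const { return num_updates_; }

     private:
      int64_t num_updates_ = 0;  // member: trailing underscore
    };
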
@@ -177,7 +177,7 @@ class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker {
     AddComment(R"DOC(
 AverageAccumulates Operator.
-Accumulate the sum of parameter whtin sliding window. The size of sliding window is
+Accumulate the sum of parameter within sliding window. The size of sliding window is
 determined by 'average_window', 'max_average_window' and 'min_average_window'.
 Memory was shared by Input(in_sum_1) and Output(out_sum_1) which acts as an accumulator 'sum_1'.
 'sum_2', 'sum_3', 'num_accumulates', 'old_num_accumulates' and 'num_updates' were the same as 'sum_1'.
......
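
For context, the accumulators these two functions read and write drive a sliding-window average of the parameters, as the DOC string above describes. Below is a standalone sketch of that accumulation step, using scalars in place of tensors; the window condition follows the kernel in average_accumulates_op.h, but treat the details as an approximation rather than the exact implementation:

    #include <algorithm>
    #include <cstdint>

    // Scalar stand-in for the state that is round-tripped through
    // GetAccumulators/SetAccumulators and the in_*/out_* tensors.
    struct AccumulatorState {
      double sum_1 = 0, sum_2 = 0, sum_3 = 0;
      int64_t num_updates = 0, num_accumulates = 0, old_num_accumulates = 0;
    };

    // One step: add the current parameter value into the running window, and
    // fold the window into sum_3 once enough updates have accumulated.
    void Accumulate(AccumulatorState* s, double param, float average_window,
                    int64_t min_average_window, int64_t max_average_window) {
      ++s->num_updates;
      ++s->num_accumulates;
      s->sum_1 += param;
      // The window counts as full once num_accumulates reaches both the lower
      // bound and average_window * num_updates (capped by max_average_window).
      int64_t window = std::min<int64_t>(
          max_average_window,
          static_cast<int64_t>(s->num_updates * average_window));
      if (s->num_accumulates >= min_average_window &&
          s->num_accumulates >= window) {
        s->sum_3 = s->sum_1 + s->sum_2;
        s->sum_1 = s->sum_2 = 0;
        s->old_num_accumulates = s->num_accumulates;
        s->num_accumulates = 0;
      }
    }
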
@@ -54,8 +54,9 @@ class AverageAccumulatesKernel : public framework::OpKernel<T> {
     float average_window = ctx.Attr<float>("average_window");
     int64_t max_average_window = ctx.Attr<int64_t>("max_average_window");
     int64_t min_average_window = ctx.Attr<int64_t>("min_average_window");
-    min_average_window =
-        std::min<int64_t>(min_average_window, max_average_window);
+    PADDLE_ENFORCE_LE(min_average_window, max_average_window,
+                      "min_average_window shouldn't be larger than "
+                      "max_average_window");
 
     // Get inputs
     auto* param = ctx.Input<Tensor>("param");
......
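
The substantive change in this last hunk is behavioral: the old code silently clamped min_average_window down with std::min, masking a misconfigured attribute pair, while the new code rejects it outright. A sketch of the same fail-fast pattern in portable C++, with a plain function standing in for Paddle's PADDLE_ENFORCE_LE macro (whose exact expansion is not shown in this diff):

    #include <cstdint>
    #include <sstream>
    #include <stdexcept>

    // Stand-in for PADDLE_ENFORCE_LE: throw instead of clamping, so a bad
    // attribute combination surfaces immediately rather than quietly
    // shrinking the averaging window.
    inline void EnforceLE(int64_t a, int64_t b, const char* msg) {
      if (a > b) {
        std::ostringstream os;
        os << msg << " (got " << a << " > " << b << ")";
        throw std::runtime_error(os.str());
      }
    }

    // Usage mirroring the kernel:
    //   EnforceLE(min_average_window, max_average_window,
    //             "min_average_window shouldn't be larger than max_average_window");
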