diff --git a/paddle/fluid/operators/average_accumulates_op.cc b/paddle/fluid/operators/average_accumulates_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..808693b61c3670c2c882e104506db5e1e7f2eb9a --- /dev/null +++ b/paddle/fluid/operators/average_accumulates_op.cc @@ -0,0 +1,152 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/average_accumulates_op.h" + +namespace paddle { +namespace operators { + +template <> +void getAccumulators( + const framework::ExecutionContext& ctx, int64_t& num_updates_, + int64_t& num_accumulates_, int64_t& old_num_accumulates_) { + auto* in_old_num_accumulates = ctx.Input("old_num_accumulates"); + auto* in_num_accumulates = ctx.Input("num_accumulates"); + auto* in_num_updates = ctx.Input("num_updates"); + + old_num_accumulates_ = in_old_num_accumulates->data()[0]; + num_accumulates_ = in_num_accumulates->data()[0]; + num_updates_ = in_num_updates->data()[0]; +} + +template <> +void setAccumulators( + const framework::ExecutionContext& ctx, int64_t num_updates_, + int64_t num_accumulates_, int64_t old_num_accumulates_) { + auto* out_old_num_accumulates = ctx.Output("old_num_accumulates"); + auto* out_num_accumulates = ctx.Output("num_accumulates"); + auto* out_num_updates = ctx.Output("num_updates"); + + out_old_num_accumulates->data()[0] = old_num_accumulates_; + out_num_accumulates->data()[0] = num_accumulates_; + out_num_updates->data()[0] = num_updates_; +} + +class AverageAccumulatesOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE( + ctx->HasInput("Param"), + "Input (Param) of average_accumulates op should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("Grad"), + "Input (Grad) of average_accumulates op should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("sum_1"), + "Input (sum_1) of average_accumulates op should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("sum_2"), + "Input (sum_2) of average_accumulates op should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("sum_3"), + "Input (sum_3) of average_accumulates op should not be null."); + PADDLE_ENFORCE(ctx->HasInput("num_accumulates"), + "Input (num_accumulates) of average_accumulates op should " + "not be null."); + PADDLE_ENFORCE(ctx->HasInput("old_num_accumulates"), + "Input (old_num_accumulates) of average_accumulates op " + "should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("num_updates"), + "Input (num_updates) of average_accumulates op should not be null."); + + PADDLE_ENFORCE( + ctx->HasOutput("sum_1"), + "Output (sum_1) of average_accumulates op should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("sum_2"), + "Output (sum_2) of average_accumulates op should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("sum_3"), + "Output (sum_3) of average_accumulates op should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("num_accumulates"), + "Output (num_accumulates) of average_accumulates op should " + "not be null."); + PADDLE_ENFORCE(ctx->HasOutput("old_num_accumulates"), + "Output (old_num_accumulates) of average_accumulates op " + "should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("num_updates"), + "Output (num_updates) of average_accumulates op should not be null."); + + auto in_dim = ctx->GetInputDim("Param"); + + ctx->SetOutputDim("sum_1", in_dim); + ctx->SetOutputDim("sum_2", in_dim); + ctx->SetOutputDim("sum_3", in_dim); + ctx->SetOutputDim("num_accumulates", {1}); + ctx->SetOutputDim("old_num_accumulates", {1}); + ctx->SetOutputDim("num_updates", {1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Param")->type()), + ctx.GetPlace()); + } +}; + +class AverageAccumulatesOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AverageAccumulatesOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("sum_1", ""); + AddInput("sum_2", ""); + AddInput("sum_3", ""); + AddInput("num_accumulates", ""); + AddInput("old_num_accumulates", ""); + AddInput("num_updates", ""); + + AddOutput("sum_1", ""); + AddOutput("sum_2", ""); + AddOutput("sum_3", ""); + AddOutput("num_accumulates", ""); + AddOutput("old_num_accumulates", ""); + AddOutput("num_updates", ""); + + AddAttr("", "average_window"); + AddAttr("", "max_average_window"); + AddAttr("", "min_average_window"); + + AddComment(R"DOC( +AverageAccumulates Operator. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(average_accumulate, ops::AverageAccumulatesOp, + ops::AverageAccumulatesOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + average_accumulate, + ops::AverageAccumulatesKernel, + ops::AverageAccumulatesKernel); diff --git a/paddle/fluid/operators/average_accumulates_op.cu b/paddle/fluid/operators/average_accumulates_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..56f2f02fd23cc78c9fc4148ee79871c63d95280e --- /dev/null +++ b/paddle/fluid/operators/average_accumulates_op.cu @@ -0,0 +1,59 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/average_accumulates_op.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace operators { +template <> +void getAccumulators( + const framework::ExecutionContext& ctx, int64_t& num_updates_, + int64_t& num_accumulates_, int64_t& old_num_accumulates_) { + auto* in_old_num_accumulates = ctx.Input("old_num_accumulates"); + auto* in_num_accumulates = ctx.Input("num_accumulates"); + auto* in_num_updates = ctx.Input("num_updates"); + + memory::Copy(platform::CPUPlace(), &old_num_accumulates_, + platform::CUDAPlace(), in_old_num_accumulates->data(), + sizeof(int64_t)); + memory::Copy(platform::CPUPlace(), &num_accumulates_, platform::CUDAPlace(), + in_old_num_accumulates->data(), sizeof(int64_t)); + memory::Copy(platform::CPUPlace(), &num_updates_, platform::CUDAPlace(), + in_num_updates->data(), sizeof(int64_t)); +} + +template <> +void setAccumulators( + const framework::ExecutionContext& ctx, int64_t num_updates_, + int64_t num_accumulates_, int64_t old_num_accumulates_) { + auto* out_old_num_accumulates = ctx.Output("old_num_accumulates"); + auto* out_num_accumulates = ctx.Output("num_accumulates"); + auto* out_num_updates = ctx.Output("num_updates"); + + memory::Copy(platform::CUDAPlace(), out_old_num_accumulates->data(), + platform::CPUPlace(), &old_num_accumulates_, sizeof(int64_t)); + memory::Copy(platform::CUDAPlace(), out_num_accumulates->data(), + platform::CPUPlace(), &num_accumulates_, sizeof(int64_t)); + memory::Copy(platform::CUDAPlace(), out_num_updates->data(), + platform::CPUPlace(), &num_updates_, sizeof(int64_t)); +} +} +} + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + average_accumulate, + ops::AverageAccumulatesKernel, + ops::AverageAccumulatesKernel); diff --git a/paddle/fluid/operators/average_accumulates_op.h b/paddle/fluid/operators/average_accumulates_op.h new file mode 100644 index 0000000000000000000000000000000000000000..73814dd24b996af758046c7c1bf486190719b40d --- /dev/null +++ b/paddle/fluid/operators/average_accumulates_op.h @@ -0,0 +1,118 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenVector = framework::EigenVector; + +template +void getAccumulators(const framework::ExecutionContext& ctx, + int64_t& num_updates_, int64_t& num_accumulates_, + int64_t& old_num_accumulates_); + +template +void setAccumulators(const framework::ExecutionContext& ctx, + int64_t num_updates_, int64_t num_accumulates_, + int64_t old_num_accumulates_); + +template +class AverageAccumulatesKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + static const int64_t kMaxNumAccumulates = 16384; + // accumulators + int64_t num_updates_ = 0; + int64_t num_accumulates_ = 0; + int64_t old_num_accumulates_ = 0; + // attrs + int64_t min_average_window_; + int64_t max_average_window_; + float average_window_; + + auto* param = ctx.Input("Param"); + auto* in_sum_1 = ctx.Input("sum_1"); + auto* in_sum_2 = ctx.Input("sum_2"); + auto* in_sum_3 = ctx.Input("sum_3"); + + auto* out_sum_1 = ctx.Output("sum_1"); + auto* out_sum_2 = ctx.Output("sum_2"); + auto* out_sum_3 = ctx.Output("sum_3"); + + getAccumulators(ctx, num_updates_, num_accumulates_, + old_num_accumulates_); + average_window_ = ctx.Attr("average_window"); + max_average_window_ = + ctx.Attr("max_average_window"); // default bach number + min_average_window_ = + ctx.Attr("min_average_window"); // default 10000L + min_average_window_ = + std::min(min_average_window_, max_average_window_); + + auto param_tensor = EigenVector::Flatten(*param); + auto in_sum_1_tensor = EigenVector::Flatten(*in_sum_1); + auto in_sum_2_tensor = EigenVector::Flatten(*in_sum_2); + auto in_sum_3_tensor = EigenVector::Flatten(*in_sum_3); + auto out_sum_1_tensor = EigenVector::Flatten(*out_sum_1); + auto out_sum_2_tensor = EigenVector::Flatten(*out_sum_2); + auto out_sum_3_tensor = EigenVector::Flatten(*out_sum_3); + + auto& place = *ctx.template device_context().eigen_device(); + math::SetConstant constant_functor; + // start batch + ++num_updates_; + ++num_accumulates_; + + // update + out_sum_1_tensor.device(place) = in_sum_1_tensor + param_tensor; + + out_sum_2_tensor.device(place) = in_sum_2_tensor; + out_sum_3_tensor.device(place) = in_sum_3_tensor; + // needSpecialTraversal + if (num_updates_ % kMaxNumAccumulates == 0) { + out_sum_2_tensor.device(place) = in_sum_2_tensor + in_sum_1_tensor; + constant_functor(ctx.template device_context(), out_sum_1, + 0.0); + } + + if (num_accumulates_ >= min_average_window_ && + num_accumulates_ >= std::min(max_average_window_, + num_updates_ * average_window_)) { + out_sum_3_tensor.device(place) = in_sum_1_tensor + in_sum_2_tensor; + constant_functor(ctx.template device_context(), out_sum_1, + 0.0); + constant_functor(ctx.template device_context(), out_sum_2, + 0.0); + + // finishBatch + old_num_accumulates_ = num_accumulates_; + num_accumulates_ = 0; + } + setAccumulators(ctx, num_updates_, num_accumulates_, + old_num_accumulates_); + } +}; + +} // namespace operators +} // namespace paddle