diff --git a/paddle/operators/adam_op.h b/paddle/operators/adam_op.h
index 887258530cbe3cd4a7ef47569a764c8f7b49c5cd..c4e2c8bb88ec9c74bd782570c10fb217178c8e48 100644
--- a/paddle/operators/adam_op.h
+++ b/paddle/operators/adam_op.h
@@ -16,7 +16,7 @@ limitations under the License. */
 #include <math.h>  // for sqrt in CPU and CUDA
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/detail/safe_ref.h"
-#include "paddle/platform/transform.h"
+#include "paddle/platform/for_range.h"
 
 namespace paddle {
 namespace operators {
@@ -36,10 +36,12 @@ struct AdamFunctor {
   const T* lr_;
   const T* grad_;
   const T* param_;
+  T* param_out_;
 
   AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
               const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
-              T* mom2_out, const T* lr, const T* grad, const T* param)
+              T* mom2_out, const T* lr, const T* grad, const T* param,
+              T* param_out)
       : beta1_(beta1),
         beta2_(beta2),
         epsilon_(epsilon),
@@ -51,11 +53,10 @@ struct AdamFunctor {
         moment2_out_(mom2_out),
         lr_(lr),
         grad_(grad),
-        param_(param) {}
+        param_(param),
+        param_out_(param_out) {}
 
-  // From param[i] --> param_out[i];
-  inline HOSTDEVICE T operator()(const T& p) const {
-    size_t i = &p - param_;
+  inline HOSTDEVICE void operator()(size_t i) const {
     // Merge all memory access together.
     T g = grad_[i];
     T mom1 = moment1_[i];
@@ -63,17 +64,18 @@
     T lr = *lr_;
     T beta1_pow = *beta1_pow_;
     T beta2_pow = *beta2_pow_;
+    T p = param_[i];
 
     // Calculation
-    lr = lr * sqrt(1 - beta2_pow) / (1 - beta1_pow);
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
     mom1 = beta1_ * mom1 + (1 - beta1_) * g;
     mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
-    T new_p = p - lr * (mom1 / (sqrt(mom2) + epsilon_));
+    p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
 
     // Write back to global memory
     moment1_out_[i] = mom1;
     moment2_out_[i] = mom2;
-    return new_p;
+    param_out_[i] = p;
   }
 };
 
@@ -113,13 +115,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
         mom2.template data<T>(),
         mom2_out.template mutable_data<T>(ctx.GetPlace()),
         lr.template data<T>(), grad.template data<T>(),
-        param.template data<T>());
-
-    const T* in_ptr = param.template data<T>();
-    T* out_ptr = param_out.template mutable_data<T>(ctx.GetPlace());
-    platform::Transform<DeviceContext> trans;
-    trans(static_cast<const DeviceContext&>(ctx.device_context()), in_ptr,
-          in_ptr + param_out.numel(), out_ptr, functor);
+        param.template data<T>(),
+        param_out.template mutable_data<T>(ctx.GetPlace()));
+    platform::ForRange<DeviceContext> for_range(
+        static_cast<const DeviceContext&>(ctx.device_context()), param.numel());
+    for_range(functor);
   }
 };
 
diff --git a/paddle/platform/for_range.h b/paddle/platform/for_range.h
new file mode 100644
index 0000000000000000000000000000000000000000..6ba6b01076103cf5660718b32a1989c14bc6dd70
--- /dev/null
+++ b/paddle/platform/for_range.h
@@ -0,0 +1,85 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/platform/device_context.h"
+
+namespace paddle {
+namespace platform {
+
+template <typename DeviceContext>
+struct ForRange {
+  ForRange(const DeviceContext& dev_ctx, size_t limit);
+
+  template <typename Function>
+  void operator()(Function func) const;
+};
+
+template <>
+struct ForRange<CPUDeviceContext> {
+  ForRange(const CPUDeviceContext& dev_ctx, size_t limit) : limit_(limit) {}
+
+  template <typename Function>
+  void operator()(Function func) const {
+    for (size_t i = 0; i < limit_; ++i) {
+      func(i);
+    }
+  }
+
+  size_t limit_;
+};
+
+#ifdef __NVCC__
+template <typename Function>
+__global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
+  size_t idx = static_cast<size_t>(threadIdx.x);
+  func(idx);
+}
+
+template <typename Function>
+__global__ static void ForRangeElemwiseOp(Function func, int limit) {
+  size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
+  if (idx < limit) {
+    func(idx);
+  }
+}
+
+template <>
+struct ForRange<CUDADeviceContext> {
+  ForRange(const CUDADeviceContext& dev_ctx, size_t limit)
+      : dev_ctx_(dev_ctx), limit_(static_cast<int>(limit)) {}
+
+  template <typename Function>
+  inline void operator()(Function func) const {
+    constexpr size_t num_threads = 1024;
+    int block_size = limit_ <= num_threads ? limit_ : num_threads;
+    int grid_size = (limit_ + num_threads - 1) / num_threads;
+
+    if (grid_size == 1) {
+      ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
+          func);
+    } else {
+      ForRangeElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
+          func, limit_);
+    }
+  }
+
+  const CUDADeviceContext& dev_ctx_;
+  int limit_;
+};
+
+#endif
+
+}  // namespace platform
+}  // namespace paddle
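
Usage sketch (not part of the diff above). The change swaps platform::Transform's iterator-pair interface for ForRange's index interface: instead of receiving an element reference and recovering the index with pointer arithmetic ("size_t i = &p - param_;"), the functor is now called once per index. The standalone C++ sketch below mirrors the CPUDeviceContext specialization to illustrate the pattern; CpuForRange and SaxpyFunctor are illustrative names invented here, not Paddle APIs.

#include <cstddef>
#include <iostream>
#include <vector>

// Mirrors ForRange<CPUDeviceContext>::operator(): call func(i) for each index.
template <typename Function>
void CpuForRange(std::size_t limit, Function func) {
  for (std::size_t i = 0; i < limit; ++i) {
    func(i);
  }
}

// An element-wise functor in the same shape as AdamFunctor: it holds only
// scalars and raw pointers, so the same object could also be passed by value
// to a CUDA kernel.
struct SaxpyFunctor {
  float a_;
  const float* x_;
  float* y_;

  SaxpyFunctor(float a, const float* x, float* y) : a_(a), x_(x), y_(y) {}

  // y[i] = a * x[i] + y[i], one element per call.
  void operator()(std::size_t i) const { y_[i] = a_ * x_[i] + y_[i]; }
};

int main() {
  std::vector<float> x = {1, 2, 3, 4};
  std::vector<float> y = {10, 20, 30, 40};

  CpuForRange(x.size(), SaxpyFunctor(2.0f, x.data(), y.data()));

  for (float v : y) {
    std::cout << v << ' ';  // prints: 12 24 36 48
  }
  std::cout << '\n';
  return 0;
}

Because the functor's state is plain values and pointers, the identical object works with either specialization; that is what lets AdamOpKernel dispatch through platform::ForRange<DeviceContext> without any device-specific code.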