diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc
index 8f30dd5b2e68a4d15d849141b175b8eae503b170..65be35843bdf99c68163b3e62f8dcbc2648a3a2d 100644
--- a/paddle/fluid/operators/optimizers/lars_momentum_op.cc
+++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cc
@@ -13,46 +13,158 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
-#include "paddle/fluid/operators/optimizers/momentum_op.h"
 
 namespace paddle {
 namespace operators {
 
+class LarsMomentumOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInputs("Param"), "Input", "Param", "LarsMomentum");
+    OP_INOUT_CHECK(ctx->HasInputs("Grad"), "Input", "Grad", "LarsMomentum");
+    OP_INOUT_CHECK(ctx->HasInputs("Velocity"), "Input", "Velocity",
+                   "LarsMomentum");
+    OP_INOUT_CHECK(ctx->HasInputs("LearningRate"), "Input", "LearningRate",
+                   "LarsMomentum");
+    OP_INOUT_CHECK(ctx->HasOutputs("ParamOut"), "Output", "ParamOut",
+                   "LarsMomentum");
+    OP_INOUT_CHECK(ctx->HasOutputs("VelocityOut"), "Output", "VelocityOut",
+                   "LarsMomentum");
+    PADDLE_ENFORCE_EQ(
+        ctx->GetInputsVarType("Param").front(),
+        framework::proto::VarType::LOD_TENSOR,
+        platform::errors::InvalidArgument(
+            "The input var's type should be LoDTensor, but the received is %s",
+            ctx->GetInputsVarType("Param").front()));
+
+    auto lr_dims = ctx->GetInputsDim("LearningRate");
+    auto grad_dim = ctx->GetInputsDim("Grad");
+    auto param_dim = ctx->GetInputsDim("Param");
+    auto velocity_dim = ctx->GetInputsDim("Velocity");
+    auto lars_weight_decays =
+        ctx->Attrs().Get<std::vector<float>>("lars_weight_decay");
+    auto multi_precision = ctx->Attrs().Get<bool>("multi_precision");
+
+    PADDLE_ENFORCE_EQ(
+        param_dim.size(), grad_dim.size(),
+        platform::errors::InvalidArgument(
+            "Input(Param) and Input(Grad) of LarsMomentumOp should have "
+            "the same quantity. But the number of Param is [%d] and that "
+            "of Grad is [%d].",
+            param_dim.size(), grad_dim.size()));
+    PADDLE_ENFORCE_EQ(
+        param_dim.size(), velocity_dim.size(),
+        platform::errors::InvalidArgument(
+            "Input(Param) and Input(Velocity) of LarsMomentumOp should "
+            "have the same quantity. But the number of Param is [%d] and "
+            "that of Velocity is [%d].",
+            param_dim.size(), velocity_dim.size()));
+    PADDLE_ENFORCE_EQ(
+        lars_weight_decays.size(), grad_dim.size(),
+        platform::errors::InvalidArgument(
+            "Attr(lars_weight_decay) and Input(Grad) of LarsMomentumOp "
+            "should have the same quantity. But the number of "
+            "lars_weight_decay is [%d] and that of Grad is [%d].",
+            lars_weight_decays.size(), grad_dim.size()));
+
+    if (multi_precision) {
+      OP_INOUT_CHECK(ctx->HasInputs("MasterParam"), "Input", "MasterParam",
+                     "LarsMomentumMultiPrecision");
+      OP_INOUT_CHECK(ctx->HasOutputs("MasterParamOut"), "Output",
+                     "MasterParamOut", "LarsMomentumMultiPrecision");
+    }
+    for (size_t i = 0; i < lr_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(framework::product(lr_dims[i]), 1,
+                        platform::errors::InvalidArgument(
+                            "LearningRate should be a scalar, but received "
+                            "a LearningRate whose numel is [%d].",
+                            framework::product(lr_dims[i])));
+    }
+
+    for (size_t i = 0; i < param_dim.size(); ++i) {
+      PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Grad")[i],
+                        framework::proto::VarType::LOD_TENSOR,
+                        platform::errors::InvalidArgument(
+                            "The Var(%s)'s type should be LoDTensor, "
+                            "but the received is %s",
+                            ctx->Inputs("Grad")[i].front(),
+                            ctx->GetInputsVarType("Grad")[i]));
+      PADDLE_ENFORCE_EQ(
+          param_dim[i], grad_dim[i],
+          platform::errors::InvalidArgument(
+              "Input(Param) and Input(Grad) of LarsMomentumOp shall have "
+              "the same dimension. But Param's dim is [%s] and Grad's dim "
+              "is [%s].",
+              param_dim[i], grad_dim[i]));
+      PADDLE_ENFORCE_EQ(
+          param_dim[i], velocity_dim[i],
+          platform::errors::InvalidArgument(
+              "Input(Param) and Input(Velocity) of LarsMomentumOp shall "
+              "have the same dimension. But Param's dim is [%s] and "
+              "Velocity's dim is [%s].",
+              param_dim[i], velocity_dim[i]));
+    }
+    ctx->SetOutputsDim("ParamOut", param_dim);
+    ctx->SetOutputsDim("VelocityOut", param_dim);
+    if (ctx->HasOutputs("MasterParamOut")) {
+      ctx->SetOutputsDim("MasterParamOut", param_dim);
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type =
+        OperatorWithKernel::IndicateVarDataType(ctx, "Param");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+};
+
 class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("Param",
              "(LoDTensor, default LoDTensor<float>) "
-             "Input parameter that has to be updated");
+             "Input parameter that has to be updated")
+        .AsDuplicable();
     AddInput("Grad",
              "(LoDTensor, default LoDTensor<float>) "
-             "Input gradient of the parameter");
+             "Input gradient of the parameter")
+        .AsDuplicable();
     AddInput("Velocity",
              "(LoDTensor, default LoDTensor<float>) "
              "Input velocity (corresponding to the parameter) "
-             "that has to be updated");
+             "that has to be updated")
+        .AsDuplicable();
     AddInput("LearningRate",
              "(LoDTensor, default LoDTensor<float>) "
-             "Input learning rate");
-    AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
-
+             "Input learning rate")
+        .AsDuplicable();
+    AddInput("MasterParam", "FP32 master weight for AMP.")
+        .AsDuplicable()
+        .AsDispensable();
     AddOutput("ParamOut",
               "(LoDTensor) This output is updated parameter. "
-              "It shared memory with Input(Param).");
+              "It shares memory with Input(Param).")
+        .AsDuplicable();
     AddOutput("VelocityOut",
               "(LoDTensor) This output is updated velocity. "
-              "It shared memory with Input(Velocity).");
+              "It shares memory with Input(Velocity).")
+        .AsDuplicable();
     AddOutput("MasterParamOut",
               "The updated FP32 master weight for AMP. "
               "It shared memory with Input(MasterParam).")
+        .AsDuplicable()
         .AsDispensable();
-
     AddAttr<float>("mu", "(float) Momentum coefficient");
     AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
         .SetDefault(0.001);
-    AddAttr<float>("lars_weight_decay",
-                   "(float, default 0.0005) LARS weight decay")
-        .SetDefault(0.0005);
+    AddAttr<std::vector<float>>(
+        "lars_weight_decay",
+        "(std::vector<float>, default 0.0005) LARS weight decay params")
+        .SetDefault({0.0005});
     AddAttr<float>("epsilon",
                    "(float, default 0.0) epsilon to avoid Division by Zero.")
         .SetDefault(0.0);
@@ -96,7 +208,7 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(
-    lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker,
+    lars_momentum, ops::LarsMomentumOp, ops::LarsMomentumOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     ops::LarsMomentumOpVarTypeInference);
diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu
index 3e7023bd1260f538c29d811705d8c917aa7f95e0..caefd496978af2b6c34b61572fbb33725b33fab4 100644
--- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu
+++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu
@@ -18,18 +18,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
 #include "paddle/fluid/platform/fast_divmod.h"
 
-#if defined(__NVCC__) && CUDA_VERSION >= 11000
-/* Once CUDA_VERSION is beyond 11.0, cooperative_groups can be involved in
-   without adding --rdc=true compile flag, then L2_norm cuda kernel can be
-   set as a __device__ kernel rather than global kernel. On the contrary,
-   the compile flag shall be set in old version, which may affect the cuda
-   kernel performance in paddle, consequently, L2_norm kernel shall be set
-   as a __global__ kernel.
-*/
+#if CUDA_VERSION >= 11000
 #include <cooperative_groups.h>
-#define LARS_FUNCTION_FLAG __device__
-#else
-#define LARS_FUNCTION_FLAG __global__
 #endif
 
 #ifdef __HIPCC__
@@ -38,6 +28,8 @@ limitations under the License. */
 #define LARS_BLOCK_SIZE 512
 #endif
 
+#define LARS_MAX_MERGED_OPS 150
+
 namespace paddle {
 namespace operators {
 
@@ -53,6 +45,43 @@ __device__ __forceinline__ double Fma(double x, double y, double z) {
   return fma(x, y, z);
 }
 
+template <typename T>
+class LarsThreadConfig {
+ public:
+  int grid_for_norm;
+  int grid_for_lars;
+#if CUDA_VERSION >= 11000
+
+ private:
+  int grid_stride;
+
+ public:
+  explicit LarsThreadConfig(int64_t numel, int sm_num, int num_blocks_per_sm) {
+    int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
+    grid_for_lars =
+        std::min(std::min(sm_num * num_blocks_per_sm, grid), LARS_BLOCK_SIZE);
+    grid_stride = LARS_BLOCK_SIZE * grid_for_lars;
+  }
+
+  int GetRepeatTimes(int64_t numel) {
+    return (numel + grid_stride - 1) / grid_stride - 1;
+  }
+#else
+  int repeat_times;
+  explicit LarsThreadConfig(const int64_t numel) {
+    int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
+    grid_for_norm = std::min(grid, LARS_BLOCK_SIZE);
+    const int grid_stride = grid_for_norm * LARS_BLOCK_SIZE;
+    repeat_times = (numel + grid_stride - 1) / grid_stride - 1;
+    // Read 4 fp16 or float values at once, but only 2 double values.
+    grid_for_lars =
+        std::is_same<T, double>::value
+            ? (numel + (LARS_BLOCK_SIZE << 1) - 1) / (LARS_BLOCK_SIZE << 1)
+            : (numel + (LARS_BLOCK_SIZE << 2) - 1) / (LARS_BLOCK_SIZE << 2);
+  }
+#endif
+};
+
 template <typename T, typename MT, int VecSize, bool IsAmp = false>
 __device__ inline void VectorizeLarsUpdate(
     const T* __restrict__ grad, const MT* __restrict__ param,
     const MT* __restrict__ velocity, T* param_out, MT* velocity_out,
@@ -85,7 +114,6 @@ __device__ inline void VectorizeLarsUpdate(
     VecType grad_data = grad_vec[i];
     VecMType param_data = param_vec[i];
     VecMType velocity_data = velocity_vec[i];
-
 #pragma unroll
     for (int j = 0; j < VecSize; ++j) {
       MT grad_val = static_cast<MT>(grad_data[j]) * rescale_grad;
@@ -116,41 +144,49 @@ __device__ inline void VectorizeLarsUpdate(
   }
 }
 
+#if CUDA_VERSION >= 11000
+/* Once CUDA_VERSION is beyond 11, cooperative_groups can be used without the
+   --rdc=true compile flag, so the L2 norm kernel can be declared __device__
+   and take a cooperative_groups::grid_group. Otherwise, adding that flag
+   hurts kernel performance considerably, and the L2 norm kernel shall stay
+   a __global__ kernel. */
+// TODO(limingshu): declaration of cooperative_groups wrapper is invalid in
+// host code.
+template <typename T, typename MT>
+__forceinline__ __device__ void L2NormKernel(
+    const cooperative_groups::grid_group* cg,
+#else
 template <typename T, typename MT>
-LARS_FUNCTION_FLAG void L2NormKernel(
+__global__ void L2NormKernel(
+#endif
     const T* __restrict__ p_data, const T* __restrict__ g_data,
-    MT* __restrict__ p_buffer, MT* __restrict__ g_buffer,
-    const int repeat_times, const int64_t numel, const MT rescale_grad,
+    MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const int64_t numel,
+    const int repeat_times, const MT rescale_grad, const int thresh = 0,
     MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) {
+  __shared__ MT s_buffer[2];
   int tid = threadIdx.x + blockDim.x * blockIdx.x;
   int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
-  const MT rescale_grad_pow = rescale_grad * rescale_grad;
-  __shared__ MT s_buffer[2];
+  const MT rescale_pow = rescale_grad * rescale_grad;
   s_buffer[0] = static_cast<MT>(0);
   s_buffer[1] = static_cast<MT>(0);
-  MT p_tmp_val = static_cast<MT>(0);
-  MT g_tmp_val = static_cast<MT>(0);
+  MT p_tmp = static_cast<MT>(0);
+  MT g_tmp = static_cast<MT>(0);
   if (repeat_times == 0) {
     if (tid < numel) {
-      p_tmp_val = static_cast<MT>(p_data[tid]);
-      g_tmp_val = static_cast<MT>(g_data[tid]);
+      p_tmp = static_cast<MT>(p_data[tid]);
+      g_tmp = static_cast<MT>(g_data[tid]);
     }
-    s_buffer[0] += math::blockReduceSum<MT>(p_tmp_val * p_tmp_val, FINAL_MASK);
-    s_buffer[1] += math::blockReduceSum<MT>(g_tmp_val * g_tmp_val, FINAL_MASK);
+    s_buffer[0] += math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
+    s_buffer[1] += math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
   } else {
-    /* To avoid occupy too much temp buffer. Hence, slice the whole data into 2
-    parts, the front of them whose quantity is excatly multiple of grid-thread
-    number, and this part of data is delt in for loop, the rest of data is delt
-    with another step to avoid visiting data address beyond bound. */
+    /* To avoid occupying too much temp buffer, slice the data into two parts:
+       a front part whose size is an exact multiple of the grid-thread count,
+       handled in the loop below, and a remainder handled in a separate step
+       to avoid reading beyond the data bound. */
     for (int i = 0; i < repeat_times; ++i) {
-      p_tmp_val = static_cast<MT>(p_data[tid]);
-      g_tmp_val = static_cast<MT>(g_data[tid]);
+      p_tmp = static_cast<MT>(p_data[tid]);
+      g_tmp = static_cast<MT>(g_data[tid]);
       tid += grid_stride;
-      s_buffer[0] +=
-          math::blockReduceSum<MT>(p_tmp_val * p_tmp_val, FINAL_MASK);
-      s_buffer[1] +=
-          math::blockReduceSum<MT>(g_tmp_val * g_tmp_val, FINAL_MASK);
+      s_buffer[0] += math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
+      s_buffer[1] += math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
       __syncthreads();
    }
    MT p_val = 0;
@@ -168,69 +204,46 @@ LARS_FUNCTION_FLAG void L2NormKernel(
     p_buffer[blockIdx.x] = s_buffer[0];
     g_buffer[blockIdx.x] = s_buffer[1];
   }
-
 #if CUDA_VERSION >= 11000
-  // Grid sync for completely writring partial result back to gloabl memory
-  const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
-  cg.sync();
-  MT p_partial_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0;
-  MT g_partial_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0;
-  *p_n = Sqrt(math::blockReduceSum<MT>(p_partial_sum, FINAL_MASK));
-  *g_n = Sqrt(rescale_grad_pow *
-              math::blockReduceSum<MT>(g_partial_sum, FINAL_MASK));
+  cg->sync();  // Grid sync for writing partial results back to global memory
+  MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0;
+  MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0;
+  *p_n = Sqrt(math::blockReduceSum<MT>(p_part_sum, FINAL_MASK));
+  *g_n = Sqrt(rescale_pow * math::blockReduceSum<MT>(g_part_sum, FINAL_MASK));
 #endif
 }
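On the CUDA >= 11000 path, L2NormKernel is a __device__ function that receives a cooperative_groups::grid_group from its caller, so the per-block partial sums can be combined after a grid-wide barrier within a single launch. A minimal self-contained sketch of that pattern (assuming a 256-thread block, block_partials sized to gridDim.x, and a cooperative launch; names are illustrative):

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

// Grid-wide sum: block-local reduction, grid barrier, then final combine.
__global__ void GridSumKernel(const float* in, float* block_partials,
                              float* out, int n) {
  cg::grid_group grid = cg::this_grid();
  __shared__ float smem[256];  // requires blockDim.x == 256
  float val = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    val += in[i];
  }
  smem[threadIdx.x] = val;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) block_partials[blockIdx.x] = smem[0];
  grid.sync();  // every block's partial sum is now globally visible
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    float total = 0.f;
    for (int b = 0; b < gridDim.x; ++b) total += block_partials[b];
    *out = total;
  }
}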
 
 template <typename T, typename MT>
-__global__ void MomentumLarsKernel(
+__forceinline__ __device__ void MomentumUpdate(
     const T* __restrict__ param, const T* __restrict__ grad,
     const MT* __restrict__ velocity, T* param_out, MT* velocity_out,
     const MT* __restrict__ master_param, MT* __restrict__ master_param_out,
-    const MT* __restrict__ learning_rate, MT* __restrict__ p_buffer,
-    MT* __restrict__ g_buffer, const MT mu, const MT lars_coeff,
-    const MT lars_weight_decay, const MT epsilon, const MT rescale_grad,
-    const int repeat_times, const int thresh, const int64_t numel) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
-#if CUDA_VERSION >= 11000
-  MT param_norm = static_cast<MT>(0);
-  MT grad_norm = static_cast<MT>(0);
-  L2NormKernel<T, MT>(param, grad, p_buffer, g_buffer, repeat_times, numel,
-                      rescale_grad, &param_norm, &grad_norm);
-#else
-  const MT rescale_grad_pow = rescale_grad * rescale_grad;
-  MT param_parital_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0;
-  MT grad_parital_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
-  __syncthreads();
-  MT param_norm =
-      Sqrt(math::blockReduceSum<MT>(param_parital_norm, FINAL_MASK));
-  MT grad_norm = Sqrt(rescale_grad_pow *
-                      math::blockReduceSum<MT>(grad_parital_norm, FINAL_MASK));
-#endif
-
+    const MT* __restrict__ learning_rate, const MT mu,
+    const MT lars_weight_decay, const MT lars_coeff, const MT epsilon,
+    const MT rescale_grad, const MT param_norm, const MT grad_norm,
+    const int tid, const int grid_stride, const int64_t numel,
+    const bool is_amp) {
   const MT lr = learning_rate[0];
   MT local_lr = lr;
   if (lars_weight_decay > static_cast<MT>(0)) {
     local_lr = lr * lars_coeff * param_norm /
-               (Fma(lars_weight_decay, param_norm, grad_norm) + epsilon);
+               (fma(lars_weight_decay, param_norm, grad_norm) + epsilon);
   }
-
-  if (master_param_out) {
-    VectorizeLarsUpdate<T, MT, 4, true>(grad, master_param, velocity,
-                                        param_out, velocity_out, mu, local_lr,
-                                        lars_weight_decay, rescale_grad, tid,
-                                        grid_stride, numel, master_param_out);
+  if (is_amp) {
+    VectorizeLarsUpdate<T, MT, 4, true>(
+        grad, master_param, velocity, param_out, velocity_out, mu, local_lr,
+        lars_weight_decay, rescale_grad, tid, grid_stride, numel,
+        master_param_out);
   } else {
     if (std::is_same<T, float>::value ||
         std::is_same<T, paddle::platform::float16>::value) {
-      // As for multiple-precision, type T and MT cannot be more than fp16 or
-      // fp32, Then, the maximum data IO size could be set to 4.
-      VectorizeLarsUpdate<T, MT, 4>(
+      /* TODO(limingshu): pointer cast may damage memory accessing for fp16 */
+      VectorizeLarsUpdate<T, MT, 4, false>(
           grad, reinterpret_cast<const MT*>(param), velocity, param_out,
           velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid,
           grid_stride, numel);
     } else {
-      VectorizeLarsUpdate<T, MT, 2>(
+      VectorizeLarsUpdate<T, MT, 2, false>(
          grad, reinterpret_cast<const MT*>(param), velocity, param_out,
          velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid,
          grid_stride, numel);
@@ -238,144 +251,278 @@ __global__ void MomentumLarsKernel(
     }
   }
 }
 
+#if CUDA_VERSION >= 11000
+template <typename T, typename MT>
+struct LarsParamWarpper {
+  int64_t numel_arr[LARS_MAX_MERGED_OPS];
+  int repeat_arr[LARS_MAX_MERGED_OPS];
+  const T* __restrict__ p_arr[LARS_MAX_MERGED_OPS];
+  const T* __restrict__ g_arr[LARS_MAX_MERGED_OPS];
+  const MT* __restrict__ v_arr[LARS_MAX_MERGED_OPS];
+  const MT* __restrict__ lr_arr[LARS_MAX_MERGED_OPS];
+  const MT* __restrict__ master_p_arr[LARS_MAX_MERGED_OPS];
+  T* __restrict__ p_out_arr[LARS_MAX_MERGED_OPS];
+  MT* __restrict__ v_out_arr[LARS_MAX_MERGED_OPS];
+  MT* __restrict__ master_p_out_arr[LARS_MAX_MERGED_OPS];
+  MT weight_decay_arr[LARS_MAX_MERGED_OPS];
+};
+
+template <typename T, typename MT>
+__global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT>* lars_warpper,
+                                         MT* __restrict__ p_buffer,
+                                         MT* __restrict__ g_buffer,
+                                         const int op_num, const MT mu,
+                                         const MT lars_coeff, const MT epsilon,
+                                         const MT rescale_grad,
+                                         const bool is_amp) {
+  int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
+  for (int i = 0; i < op_num; ++i) {
+    int numel = lars_warpper->numel_arr[i];
+    MT param_norm = static_cast<MT>(0);
+    MT grad_norm = static_cast<MT>(0);
+    L2NormKernel(&cg, lars_warpper->p_arr[i], lars_warpper->g_arr[i],
+                 p_buffer, g_buffer, numel, lars_warpper->repeat_arr[i],
+                 rescale_grad, 0, &param_norm, &grad_norm);
+    MomentumUpdate(
+        lars_warpper->p_arr[i], lars_warpper->g_arr[i],
+        lars_warpper->v_out_arr[i], lars_warpper->p_out_arr[i],
+        lars_warpper->v_out_arr[i], lars_warpper->master_p_arr[i],
+        lars_warpper->master_p_out_arr[i], lars_warpper->lr_arr[i], mu,
+        lars_warpper->weight_decay_arr[i], lars_coeff, epsilon, rescale_grad,
+        param_norm, grad_norm, tid, grid_stride, numel, is_amp);
+  }
+}
+#endif
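LarsParamWarpper is a plain POD aggregate of fixed-size arrays, so the pointers and sizes of up to LARS_MAX_MERGED_OPS tensors can be handed to one kernel with a single host-to-device copy. A minimal sketch of the same technique under simplified assumptions (a plain SGD step instead of LARS, raw CUDA runtime calls instead of Paddle's memory helpers; all names are illustrative):

#include <cuda_runtime.h>
#include <cstdint>

constexpr int kMaxOps = 8;  // illustrative capacity

// POD aggregate: fixed-size arrays only, so one memcpy moves everything.
struct ParamPack {
  int64_t numel[kMaxOps];
  const float* grad[kMaxOps];
  float* param[kMaxOps];
};

__global__ void ApplyAll(const ParamPack* pack, int op_num, float lr) {
  for (int i = 0; i < op_num; ++i) {
    for (int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
         j < pack->numel[i];
         j += static_cast<int64_t>(gridDim.x) * blockDim.x) {
      pack->param[i][j] -= lr * pack->grad[i][j];
    }
  }
}

void LaunchAll(const ParamPack& host_pack, int op_num, float lr,
               cudaStream_t stream) {
  ParamPack* dev_pack = nullptr;
  cudaMalloc(&dev_pack, sizeof(ParamPack));
  // One copy carries every tensor's pointer and length to the device.
  cudaMemcpyAsync(dev_pack, &host_pack, sizeof(ParamPack),
                  cudaMemcpyHostToDevice, stream);
  ApplyAll<<<80, 256, 0, stream>>>(dev_pack, op_num, lr);  // grid illustrative
  cudaStreamSynchronize(stream);
  cudaFree(dev_pack);
}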
+
+template <typename T, typename MT>
+__global__ void MomentumLarsKernel(
+    const T* __restrict__ param, const T* __restrict__ grad,
+    const MT* __restrict__ velocity, T* param_out, MT* velocity_out,
+    const MT* __restrict__ master_param, MT* __restrict__ master_param_out,
+    const MT* __restrict__ learning_rate, MT* __restrict__ p_buffer,
+    MT* __restrict__ g_buffer, const MT mu, const MT lars_coeff,
+    const MT lars_weight_decay, const MT epsilon, const MT rescale_grad,
+    const int repeat_times, const int thresh, const int64_t numel,
+    const bool is_amp) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
+#if CUDA_VERSION >= 11000
+  const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
+  MT param_norm = static_cast<MT>(0);
+  MT grad_norm = static_cast<MT>(0);
+  L2NormKernel(&cg, param, grad, p_buffer, g_buffer, numel, repeat_times,
+               rescale_grad, gridDim.x, &param_norm, &grad_norm);
+#else
+  const MT rescale_grad_pow = rescale_grad * rescale_grad;
+  MT param_part_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0;
+  MT grad_part_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
+  __syncthreads();
+  MT param_norm = Sqrt(math::blockReduceSum<MT>(param_part_norm, FINAL_MASK));
+  MT grad_norm = Sqrt(rescale_grad_pow *
+                      math::blockReduceSum<MT>(grad_part_norm, FINAL_MASK));
+#endif
+  MomentumUpdate<T, MT>(param, grad, velocity, param_out, velocity_out,
+                        master_param, master_param_out, learning_rate, mu,
+                        lars_weight_decay, lars_coeff, epsilon, rescale_grad,
+                        param_norm, grad_norm, tid, grid_stride, numel,
+                        is_amp);
+}
+
+template <typename T, typename MT>
+inline void SeparatedLarsMomentumOpCUDAKernel(
+    const platform::CUDADeviceContext& cuda_ctx, const T* param_data,
+    T* param_out_data, const MT* velocity_data, MT* velocity_out_data,
+    const T* grad_data, const MT* lr, MT* p_buffer, MT* g_buffer, const MT mu,
+    const MT lars_coeff, const MT weight_decay, const MT epsilon,
+    const MT rescale_grad, const int64_t numel, const MT* master_param_data,
+    MT* master_out_data, const bool is_amp) {
+  LarsThreadConfig<T> lars_thread_config(numel);
+  L2NormKernel<T, MT><<<lars_thread_config.grid_for_norm, LARS_BLOCK_SIZE, 0,
+                        cuda_ctx.stream()>>>(
+      param_data, grad_data, p_buffer, g_buffer, numel,
+      lars_thread_config.repeat_times, rescale_grad);
+
+  MomentumLarsKernel<T, MT><<<lars_thread_config.grid_for_lars,
+                              LARS_BLOCK_SIZE, 0, cuda_ctx.stream()>>>(
+      param_data, grad_data, velocity_data, param_out_data, velocity_out_data,
+      master_param_data, master_out_data, lr, p_buffer, g_buffer, mu,
+      lars_coeff, weight_decay, epsilon, rescale_grad, 0,
+      lars_thread_config.grid_for_norm, numel, is_amp);
+}
+
 template <typename T>
 class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
   using MT = MultiPrecisionType<T>;
 
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    const bool multi_precision = ctx.Attr<bool>("multi_precision");
-    auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
-    auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
-    auto param = ctx.Input<framework::LoDTensor>("Param");
-    auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
-    auto grad = ctx.Input<framework::LoDTensor>("Grad");
-    auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
-
-    int64_t numel = param->numel();
-    int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
-    const framework::Tensor* master_param = nullptr;
-    framework::Tensor* master_param_out = nullptr;
-    const MT* master_param_data = nullptr;
-    MT* master_param_out_data = nullptr;
-
-    if (multi_precision) {
-      bool has_master =
-          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
-      PADDLE_ENFORCE_EQ(has_master, true,
-                        platform::errors::InvalidArgument(
-                            "The Input(MasterParam) and Output(MasterParamOut) "
-                            "should not be null when "
-                            "the attr `multi_precision` is true"));
-      master_param = ctx.Input<framework::LoDTensor>("MasterParam");
-      master_param_out = ctx.Output<framework::LoDTensor>("MasterParamOut");
-      master_param_data = master_param->data<MT>();
-      master_param_out_data =
-          master_param_out->mutable_data<MT>(ctx.GetPlace());
-    }
-    MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
-    MT lars_coeff = static_cast<MT>(ctx.Attr<float>("lars_coeff"));
-    MT lars_weight_decay =
-        static_cast<MT>(ctx.Attr<float>("lars_weight_decay"));
-    MT epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
-    MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad"));
-
-    auto* param_data = param->data<T>();
-    auto* grad_data = grad->data<T>();
-    auto* velocity_data = velocity->data<MT>();
-    auto* lr = learning_rate->data<MT>();
-    auto& cuda_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    T* param_out_data = param_out->mutable_data<T>(ctx.GetPlace());
-    MT* velocity_out_data = velocity_out->mutable_data<MT>(ctx.GetPlace());
-
-#if CUDA_VERSION >= 11000
-    /*
-    Once model trainning with lars optimizer, whose principal implementation
-    is achieved by following two steps:
-      1. Figure out the L2 norm statistic result of grad data and param data.
-      2. Update param and velocity data with usage of L2 norm statistic result.
-    Orignally, these two steps were fulfilled by respective eigen function and
-    cuda kernel, however the overhead of eigen function occupied much ratio in
-    total, consequently affect the performance of lars op, make it necessary
-    to combine 2 steps into one cuda kernel.
-    Since the step1 is l2 norm statistic, grid level reduce is needed. To
-    achieve this and continuous calculation of step 2 in only one global
-    lanuch, essential basis is to control all grid-threads while running. Apart
-    from normal lanuch form, cuda9.0 provides `cudaLaunchCooperativeKernel`
-    api :
-      - The thread quantity shall less than pyhsical SM limited threads
-      - Launches a device function where thread blocks can cooperate and
-        synchronize as they execute.
-    */
-    // Figure out how many blocks can be active in each sm.
     int num_blocks_per_sm = 0;
-    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm,
-                                                  MomentumLarsKernel<T, MT>,
-                                                  LARS_BLOCK_SIZE, sizeof(MT));
+    bool multi_precision = ctx.Attr<bool>("multi_precision");
+    auto& cuda_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     int sm_num = cuda_ctx.GetSMCount();
-    int grid_real =
-        std::min(std::min(sm_num * num_blocks_per_sm, grid), LARS_BLOCK_SIZE);
     framework::Tensor tmp_buffer_t =
         ctx.AllocateTmpTensor<MT, platform::CUDADeviceContext>(
             {LARS_BLOCK_SIZE << 1}, cuda_ctx);
     auto* p_buffer = tmp_buffer_t.mutable_data<MT>(ctx.GetPlace());
     auto* g_buffer = p_buffer + LARS_BLOCK_SIZE;
-    int grid_stride = LARS_BLOCK_SIZE * grid;
-    int repeat_times = (numel + grid_stride - 1) / grid_stride - 1;
-    int thresh = 0;
-
-    // Uniform kernel parameter for cudaLaunchCooperativeKernel
-    void* cuda_param[] = {
-        reinterpret_cast<void*>(&param_data),
-        reinterpret_cast<void*>(&grad_data),
-        reinterpret_cast<void*>(&velocity_data),
-        reinterpret_cast<void*>(&param_out_data),
-        reinterpret_cast<void*>(&velocity_out_data),
-        reinterpret_cast<void*>(&master_param_data),
-        reinterpret_cast<void*>(&master_param_out_data),
-        reinterpret_cast<void*>(&lr),
-        reinterpret_cast<void*>(&p_buffer),
-        reinterpret_cast<void*>(&g_buffer),
-        reinterpret_cast<void*>(&mu),
-        reinterpret_cast<void*>(&lars_coeff),
-        reinterpret_cast<void*>(&lars_weight_decay),
-        reinterpret_cast<void*>(&epsilon),
-        reinterpret_cast<void*>(&rescale_grad),
-        reinterpret_cast<void*>(&repeat_times),
-        reinterpret_cast<void*>(&thresh),  // Just a placeholder
-        reinterpret_cast<void*>(&numel)};
-    // Lanuch all sm theads.
-    cudaLaunchCooperativeKernel(
-        reinterpret_cast<void*>(MomentumLarsKernel<T, MT>), grid_real,
-        LARS_BLOCK_SIZE, cuda_param, 0, cuda_ctx.stream());
-#else
-    // Determine to read 4 fp16 or float data once, but 2 double data once.
-    int grid_lars =
-        sizeof(T) < sizeof(double)
-            ? (numel + (LARS_BLOCK_SIZE << 2) - 1) / (LARS_BLOCK_SIZE << 2)
-            : (numel + (LARS_BLOCK_SIZE << 1) - 1) / (LARS_BLOCK_SIZE << 1);
-    int grid_norm = std::min(grid, LARS_BLOCK_SIZE);
-    framework::Tensor p_buffer_t =
-        ctx.AllocateTmpTensor<MT, platform::CUDADeviceContext>(
-            {LARS_BLOCK_SIZE << 1}, cuda_ctx);
-    auto* p_buffer = p_buffer_t.mutable_data<MT>(ctx.GetPlace());
-    auto* g_buffer = p_buffer + LARS_BLOCK_SIZE;
-
-    const int grid_stride = LARS_BLOCK_SIZE * grid_norm;
-    const int repeat_times = (numel + grid_stride - 1) / grid_stride - 1;
-
-    L2NormKernel<T, MT><<<grid_norm, LARS_BLOCK_SIZE, 0, cuda_ctx.stream()>>>(
-        param_data, grad_data, p_buffer, g_buffer, repeat_times, numel,
-        rescale_grad);
+    MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
+    MT lars_coeff = static_cast<MT>(ctx.Attr<float>("lars_coeff"));
+    MT epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
+    MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad"));
 
-    MomentumLarsKernel<
-        T, MT><<<grid_lars, LARS_BLOCK_SIZE, 0, cuda_ctx.stream()>>>(
-        param_data, grad_data, velocity_data, param_out_data, velocity_out_data,
-        master_param_data, master_param_out_data, lr, p_buffer, g_buffer, mu,
-        lars_coeff, lars_weight_decay, epsilon, rescale_grad, 0, grid_norm,
-        numel);  // 0 is just a placeholder.
+    auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
+    auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
+    auto param = ctx.MultiInput<framework::LoDTensor>("Param");
+    auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity");
+    auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
+    auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
+    auto learning_rate = ctx.MultiInput<framework::LoDTensor>("LearningRate");
+    auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam");
+    auto master_param_out =
+        ctx.MultiOutput<framework::LoDTensor>("MasterParamOut");
+
+    int op_num = grad.size();
+#if CUDA_VERSION >= 11000
+    if (op_num > 1) {
+      LarsParamWarpper<T, MT> lars_warpper;
+      PADDLE_ENFORCE_LT(
+          op_num, LARS_MAX_MERGED_OPS,
+          platform::errors::InvalidArgument(
+              "The maximum number of merged ops supported is (%d), but "
+              "the number of lars ops required for training this model "
+              "is (%d).\n",
+              LARS_MAX_MERGED_OPS, op_num));
+
+      /* The lars optimizer consists of two steps:
+           1. Figure out the L2 norms of the grad data and the param data.
+           2. Update param and velocity with the L2 norm results.
+         Step 1 and step 2 can be merged with the cudaLaunchCooperativeKernel
+         API provided by NVIDIA:
+           - the total thread quantity shall be less than what the physical
+             SMs can keep resident, and
+           - thread blocks of the launched grid can synchronize with each
+             other as they execute. */
+      cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+          &num_blocks_per_sm, MergedMomentumLarsKernel<T, MT>,
+          LARS_BLOCK_SIZE, sizeof(MT) << 1);
+
+      size_t total_numel = 0;
+      for (int i = 0; i < op_num; ++i) {
+        size_t temp_numel = param[i]->numel();
+        total_numel += temp_numel;
+        lars_warpper.numel_arr[i] = temp_numel;
+        lars_warpper.p_arr[i] = param[i]->data<T>();
+        lars_warpper.g_arr[i] = grad[i]->data<T>();
+        lars_warpper.v_arr[i] = velocity[i]->data<MT>();
+        lars_warpper.lr_arr[i] = learning_rate[i]->data<MT>();
+        lars_warpper.p_out_arr[i] =
+            param_out[i]->mutable_data<T>(ctx.GetPlace());
+        lars_warpper.v_out_arr[i] =
+            velocity_out[i]->mutable_data<MT>(ctx.GetPlace());
+        lars_warpper.weight_decay_arr[i] =
+            static_cast<MT>(weight_decay_arr[i]);
+      }
+      int64_t avg_numel = total_numel / op_num;
+      LarsThreadConfig<T> lars_thread_config(avg_numel, sm_num,
+                                             num_blocks_per_sm);
+      for (int i = 0; i < op_num; ++i) {
+        lars_warpper.repeat_arr[i] =
+            lars_thread_config.GetRepeatTimes(lars_warpper.numel_arr[i]);
+      }
+      if (multi_precision) {
+        for (int i = 0; i < op_num; ++i) {
+          lars_warpper.master_p_arr[i] = master_param[i]->data<MT>();
+          lars_warpper.master_p_out_arr[i] =
+              master_param_out[i]->mutable_data<MT>(ctx.GetPlace());
+        }
+      }
+      auto merged_buf = memory::Alloc(cuda_ctx, sizeof(lars_warpper));
+      auto* merged_ptr =
+          reinterpret_cast<LarsParamWarpper<T, MT>*>(merged_buf->ptr());
+      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, cuda_ctx.GetPlace()),
+                   reinterpret_cast<void*>(merged_ptr), platform::CPUPlace(),
+                   reinterpret_cast<void*>(&lars_warpper),
+                   sizeof(lars_warpper), cuda_ctx.stream());
+      void* cuda_param[] = {reinterpret_cast<void*>(&merged_ptr),
+                            reinterpret_cast<void*>(&p_buffer),
+                            reinterpret_cast<void*>(&g_buffer),
+                            reinterpret_cast<void*>(&op_num),
+                            reinterpret_cast<void*>(&mu),
+                            reinterpret_cast<void*>(&lars_coeff),
+                            reinterpret_cast<void*>(&epsilon),
+                            reinterpret_cast<void*>(&rescale_grad),
+                            reinterpret_cast<void*>(&multi_precision)};
+      // Launch all SM threads; the thread blocks cooperate and synchronize
+      // as they execute.
+      cudaLaunchCooperativeKernel(
+          reinterpret_cast<void*>(MergedMomentumLarsKernel<T, MT>),
+          lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0,
+          cuda_ctx.stream());
+    } else {
+      auto* param_data = param[0]->data<T>();
+      auto* grad_data = grad[0]->data<T>();
+      auto* velocity_data = velocity[0]->data<MT>();
+      auto* lr = learning_rate[0]->data<MT>();
+      auto* param_out_data = param_out[0]->mutable_data<T>(ctx.GetPlace());
+      auto* velocity_out_data =
+          velocity_out[0]->mutable_data<MT>(ctx.GetPlace());
+      const MT* master_param_data =
+          multi_precision ? master_param[0]->data<MT>() : nullptr;
+      MT* master_param_out_data =
+          multi_precision
+              ? master_param_out[0]->mutable_data<MT>(ctx.GetPlace())
+              : nullptr;
+      int64_t numel = param[0]->numel();
+      MT lars_weight_decay = weight_decay_arr[0];
+
+      // Figure out how many blocks can be active on each SM.
+      cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+          &num_blocks_per_sm, MomentumLarsKernel<T, MT>, LARS_BLOCK_SIZE,
+          sizeof(MT) << 1);
+      LarsThreadConfig<T> lars_thread_config(numel, sm_num,
+                                             num_blocks_per_sm);
+      int repeat_times = lars_thread_config.GetRepeatTimes(numel);
+      int thresh = 0;
+      void* cuda_param[] = {
+          reinterpret_cast<void*>(&param_data),
+          reinterpret_cast<void*>(&grad_data),
+          reinterpret_cast<void*>(&velocity_data),
+          reinterpret_cast<void*>(&param_out_data),
+          reinterpret_cast<void*>(&velocity_out_data),
+          reinterpret_cast<void*>(&master_param_data),
+          reinterpret_cast<void*>(&master_param_out_data),
+          reinterpret_cast<void*>(&lr),
+          reinterpret_cast<void*>(&p_buffer),
+          reinterpret_cast<void*>(&g_buffer),
+          reinterpret_cast<void*>(&mu),
+          reinterpret_cast<void*>(&lars_coeff),
+          reinterpret_cast<void*>(&lars_weight_decay),
+          reinterpret_cast<void*>(&epsilon),
+          reinterpret_cast<void*>(&rescale_grad),
+          reinterpret_cast<void*>(&repeat_times),
+          reinterpret_cast<void*>(&thresh),  // Just a placeholder
+          reinterpret_cast<void*>(&numel),
+          reinterpret_cast<void*>(&multi_precision)};
+      // Launch all SM threads.
+      cudaLaunchCooperativeKernel(
+          reinterpret_cast<void*>(MomentumLarsKernel<T, MT>),
+          lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0,
+          cuda_ctx.stream());
+    }
+#else
+    for (int i = 0; i < op_num; ++i) {
+      const MT* master_param_data =
+          multi_precision ? master_param[i]->data<MT>() : nullptr;
+      MT* master_param_out_data =
+          multi_precision
+              ? master_param_out[i]->mutable_data<MT>(ctx.GetPlace())
+              : nullptr;
+      SeparatedLarsMomentumOpCUDAKernel<T, MT>(
+          cuda_ctx, param[i]->data<T>(),
+          param_out[i]->mutable_data<T>(ctx.GetPlace()),
+          velocity[i]->data<MT>(),
+          velocity_out[i]->mutable_data<MT>(ctx.GetPlace()),
+          grad[i]->data<T>(), learning_rate[i]->data<MT>(), p_buffer,
+          g_buffer, mu, lars_coeff, weight_decay_arr[i], epsilon,
+          rescale_grad, param[i]->numel(), master_param_data,
+          master_param_out_data, multi_precision);
+    }
 #endif
   }
 };
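cudaLaunchCooperativeKernel only succeeds when every block of the grid can be resident on the device at once, which is why both launch sites size the grid with an occupancy query first. A minimal standalone sketch of that pattern (illustrative names, error handling omitted):

#include <algorithm>
#include <cuda_runtime.h>

__global__ void ScaleKernel(float* data, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    data[i] *= 2.f;  // stand-in for real cooperative work
  }
}

void CooperativeLaunch(float* data, int n, cudaStream_t stream) {
  constexpr int kBlockSize = 512;
  int dev = 0, sm_count = 0, blocks_per_sm = 0;
  cudaGetDevice(&dev);
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, ScaleKernel,
                                                kBlockSize, /*smem=*/0);
  // Never request more blocks than can be co-resident on the device.
  int grid = std::min((n + kBlockSize - 1) / kBlockSize,
                      sm_count * blocks_per_sm);
  void* args[] = {&data, &n};
  cudaLaunchCooperativeKernel(reinterpret_cast<void*>(ScaleKernel),
                              dim3(grid), dim3(kBlockSize), args,
                              /*sharedMem=*/0, stream);
}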
diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.h b/paddle/fluid/operators/optimizers/lars_momentum_op.h
old mode 100755
new mode 100644
index 55775bc08fb5ebc31cd231b8088a9798561fabfc..df4d7b9a0438bc103f262bb4a8971a3ee31d6ebb
--- a/paddle/fluid/operators/optimizers/lars_momentum_op.h
+++ b/paddle/fluid/operators/optimizers/lars_momentum_op.h
@@ -23,54 +23,48 @@ template <typename T>
 class LarsMomentumOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
-    auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
-    auto param = ctx.Input<framework::LoDTensor>("Param");
-    auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
-    auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
-    auto* grad_var = ctx.InputVar("Grad");
-    // only support dense for now.
-    PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
-                      platform::errors::InvalidArgument(
-                          "The Var(%s)'s type should be LoDTensor, "
-                          "but the received is %s",
-                          ctx.InputNames("Grad").front(),
-                          framework::ToTypeName(grad_var->Type())));
-    auto grad = ctx.Input<framework::LoDTensor>("Grad");
-
-    param_out->mutable_data<T>(ctx.GetPlace());
-    velocity_out->mutable_data<T>(ctx.GetPlace());
-
+    auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
+    auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
+    auto param = ctx.MultiInput<framework::LoDTensor>("Param");
+    auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity");
+    auto learning_rate = ctx.MultiInput<framework::LoDTensor>("LearningRate");
+    auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
+    auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
     T mu = static_cast<T>(ctx.Attr<float>("mu"));
     T lars_coeff = ctx.Attr<float>("lars_coeff");
-    T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
     T epsilon = ctx.Attr<float>("epsilon");
 
-    auto p_out = framework::EigenVector<T>::Flatten(*param_out);
-    auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
+    int op_num = param.size();
+    for (int i = 0; i < op_num; ++i) {
+      auto* lr = learning_rate[i]->data<T>();
+      T lars_weight_decay = weight_decay_arr[i];
+      param_out[i]->mutable_data<T>(ctx.GetPlace());
+      velocity_out[i]->mutable_data<T>(ctx.GetPlace());
 
-    auto p = framework::EigenVector<T>::Flatten(*param);
-    auto v = framework::EigenVector<T>::Flatten(*velocity);
-    auto g = framework::EigenVector<T>::Flatten(*grad);
-    auto* lr = learning_rate->data<T>();
+      auto p_out = framework::EigenVector<T>::Flatten(*(param_out[i]));
+      auto v_out = framework::EigenVector<T>::Flatten(*(velocity_out[i]));
+      auto p = framework::EigenVector<T>::Flatten(*(param[i]));
+      auto v = framework::EigenVector<T>::Flatten(*(velocity[i]));
+      auto g = framework::EigenVector<T>::Flatten(*(grad[i]));
 
-    framework::Tensor p_norm_t, g_norm_t;
-    p_norm_t.Resize({1});
-    g_norm_t.Resize({1});
-    p_norm_t.mutable_data<T>(ctx.GetPlace());
-    g_norm_t.mutable_data<T>(ctx.GetPlace());
-    auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
-    auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
+      framework::Tensor p_norm_t, g_norm_t;
+      p_norm_t.Resize({1});
+      g_norm_t.Resize({1});
+      p_norm_t.mutable_data<T>(ctx.GetPlace());
+      g_norm_t.mutable_data<T>(ctx.GetPlace());
+      auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
+      auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
+      ep_norm = p.square().sum().sqrt();
+      eg_norm = g.square().sum().sqrt();
 
-    ep_norm = p.square().sum().sqrt();
-    eg_norm = g.square().sum().sqrt();
-    T local_lr = lr[0];
-    if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
-      local_lr = lr[0] * lars_coeff * ep_norm(0) /
-                 (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
+      T local_lr = lr[0];
+      if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
+        local_lr = lr[0] * lars_coeff * ep_norm(0) /
+                   (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
+      }
+      v_out = v * mu + local_lr * (g + lars_weight_decay * p);
+      p_out = p - v_out;
     }
-    v_out = v * mu + local_lr * (g + lars_weight_decay * p);
-    p_out = p - v_out;
   }
 };
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 24076e82b0365d21e7222a16cbc3d3462699f119..b81862adf5e65698b8d9c34fa9ea306aba8f9f11 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -2066,7 +2066,7 @@ class LarsMomentumOptimizer(Optimizer):
         attrs = {
             "mu": self._momentum,
             "lars_coeff": self._lars_coeff,
-            "lars_weight_decay": _lars_weight_decay,
+            "lars_weight_decay": [_lars_weight_decay],
             "multi_precision": find_master,
             "rescale_grad": self._rescale_grad
         }
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py
index e4cc3682d1a24f80bb52d57c5734ac0686bfeb63..bee6acf732460b46f0b322bc20d67017696b4af0 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py
@@ -103,7 +103,7 @@ class TestFleetLarsMetaOptimizer(unittest.TestCase):
                 'op_role_var')[0] or ".b" in op.attr('op_role_var')[0])
         ]
         for op in ops_without_wd:
-            self.assertEqual(op.attr('lars_weight_decay'), 0)
+            self.assertEqual(op.attr('lars_weight_decay')[0], 0)
 
     def test_lars_apply_with_amp(self):
         role = role_maker.PaddleCloudRoleMaker(is_collective=True)
diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py
index b42de853c00d54c6f8ed30642c01e4ca93443fac..34e057a5a8a612114169989f0693c83a4e71111d 100644
--- a/python/paddle/fluid/tests/unittests/test_momentum_op.py
+++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py
@@ -138,50 +138,70 @@ class TestMomentumOp2(OpTest):
                  "core is not compiled with CUDA")
 class TestLarsMomentumOpWithMP(OpTest):
     def setUp(self):
+        self.config()
         self.op_type = "lars_momentum"
-
-        master_param = np.random.random((123, 321)).astype("float32")
-        param = master_param.astype("float16")
-        grad = np.random.random((123, 321)).astype("float16")
-        velocity = np.zeros((123, 321)).astype("float32")
-        learning_rate = np.array([0.001]).astype("float32")
         mu = 0.0001
         lars_coeff = 0.001
         lars_weight_decay = 0.0005
         rescale_grad = 1.0
 
+        params = []
+        grads = []
+        velocitys = []
+        learning_rates = []
+        master_params = []
+        param_outs = []
+        velocity_outs = []
+        master_param_outs = []
+        for i in range(self.params_num):
+            master_param = np.random.random((123, 321)).astype("float32")
+            param = master_param.astype("float16")
+            grad = np.random.random((123, 321)).astype("float16")
+            velocity = np.zeros((123, 321)).astype("float32")
+            learning_rate = np.array([0.001]).astype("float32")
+
+            fp32_grad = grad.astype("float32")
+            pnorm = np.sqrt(np.square(master_param).sum())
+            gnorm = np.sqrt(np.square(fp32_grad).sum())
+            local_lr = learning_rate * lars_coeff * pnorm / (
+                gnorm + lars_weight_decay * pnorm)
+            fp32_grad = fp32_grad * rescale_grad
+            velocity_out = mu * velocity + local_lr * (
+                fp32_grad + lars_weight_decay * master_param)
+            p_new = master_param - velocity_out
+            param_out = p_new.astype("float16")
+            master_param_out = p_new
+
+            params.append(("SubParam_" + str(i), param))
+            grads.append(("SubGrad_" + str(i), grad))
+            velocitys.append(("SubVelocity_" + str(i), velocity))
+            learning_rates.append(("SubLearning_rate_" + str(i), learning_rate))
+            velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out))
+            param_outs.append(("SubParam_out_" + str(i), param_out))
+            master_params.append(("SubMasterParam_" + str(i), master_param))
+            master_param_outs.append(
+                ("SubMasterParamOut_" + str(i), master_param_out))
+
         self.inputs = {
-            'Param': param,
-            'Grad': grad,
-            'Velocity': velocity,
-            'LearningRate': learning_rate,
-            'MasterParam': master_param,
+            'Param': params,
+            'Grad': grads,
+            'Velocity': velocitys,
+            'LearningRate': learning_rates,
+            'MasterParam': master_params,
         }
 
         self.attrs = {
             'mu': mu,
             'lars_coeff': lars_coeff,
-            'lars_weight_decay': lars_weight_decay,
+            'lars_weight_decay': [lars_weight_decay],
             'multi_precision': True,
             'rescale_grad': rescale_grad
         }
 
-        fp32_grad = grad.astype("float32")
-        pnorm = np.sqrt(np.square(master_param).sum())
-        gnorm = np.sqrt(np.square(fp32_grad).sum())
-        local_lr = learning_rate * lars_coeff * pnorm / (
-            gnorm + lars_weight_decay * pnorm)
-        fp32_grad = fp32_grad * rescale_grad
-        velocity_out = mu * velocity + local_lr * (fp32_grad + lars_weight_decay
-                                                   * master_param)
-        p_new = master_param - velocity_out
-        param_out = p_new.astype("float16")
-        master_param_out = p_new
-
         self.outputs = {
-            'ParamOut': param_out,
-            'VelocityOut': velocity_out,
-            'MasterParamOut': master_param_out
+            'ParamOut': param_outs,
+            'VelocityOut': velocity_outs,
+            'MasterParamOut': master_param_outs
         }
 
     def test_check_output(self):
@@ -191,46 +211,65 @@ class TestLarsMomentumOpWithMP(OpTest):
         if core.is_float16_supported(place):
             self.check_output_with_place(place)
 
+    def config(self):
+        self.params_num = 1
+
 
 class TestLarsMomentumOp(OpTest):
     def setUp(self):
+        self.config()
         self.op_type = "lars_momentum"
-
-        param = np.random.random((123, 321)).astype("float32")
-        grad = np.random.random((123, 321)).astype("float32")
-        velocity = np.zeros((123, 321)).astype("float32")
-        learning_rate = np.array([0.001]).astype("float32")
         mu = 0.0001
         lars_coeff = 0.001
         lars_weight_decay = 0.0005
 
+        params = []
+        grads = []
+        velocitys = []
+        param_outs = []
+        velocity_outs = []
+        learning_rates = []
+        for i in range(self.params_num):
+            param = np.random.random((123, 321)).astype("float32")
+            grad = np.random.random((123, 321)).astype("float32")
+            velocity = np.zeros((123, 321)).astype("float32")
+            learning_rate = np.array([0.001]).astype("float32")
+            pnorm = np.sqrt(np.square(param).sum())
+            gnorm = np.sqrt(np.square(grad).sum())
+            local_lr = learning_rate * lars_coeff * pnorm / (
+                gnorm + lars_weight_decay * param)
+            velocity_out = mu * velocity + local_lr * (
+                grad + lars_weight_decay * param)
+            param_out = param - velocity_out
+
+            params.append(("SubParam_" + str(i), param))
+            grads.append(("SubGrad_" + str(i), grad))
+            velocitys.append(("SubVelocity_" + str(i), velocity))
+            learning_rates.append(("SubLearning_rate_" + str(i), learning_rate))
+            velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out))
+            param_outs.append(("SubParam_out_" + str(i), param_out))
+
         self.inputs = {
-            'Param': param,
-            'Grad': grad,
-            'Velocity': velocity,
-            'LearningRate': learning_rate
+            'Param': params,
+            'Grad': grads,
+            'Velocity': velocitys,
+            'LearningRate': learning_rates
         }
 
         self.attrs = {
             'mu': mu,
             'lars_coeff': lars_coeff,
-            'lars_weight_decay': lars_weight_decay
+            'lars_weight_decay': [lars_weight_decay]
         }
-
-        pnorm = np.sqrt(np.square(param).sum())
-        gnorm = np.sqrt(np.square(grad).sum())
-        local_lr = learning_rate * lars_coeff * pnorm / (
-            gnorm + lars_weight_decay * param)
-        velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay *
-                                                   param)
-        param_out = param - velocity_out
-
-        self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
+        self.outputs = {'ParamOut': param_outs, 'VelocityOut': velocity_outs}
 
     def test_check_output(self):
         paddle.enable_static()
         self.check_output()
 
+    def config(self):
+        self.params_num = 1
+
 
 class TestSparseMomentumOp(unittest.TestCase):
     def setUp(self):
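The multi-precision test mirrors the contract of the is_amp device path: all arithmetic happens on the float32 master copy, and float16 is only the storage format of Param and Grad. A minimal CUDA sketch of that contract for one tensor, with local_lr precomputed from the norms on the host (illustrative only, not the op's actual kernel):

#include <cuda_fp16.h>
#include <cstdint>

// AMP master-weight update: widen fp16 once, compute in fp32, narrow once.
__global__ void MasterWeightUpdate(const __half* grad, float* master_param,
                                   float* velocity, __half* param_out,
                                   float local_lr, float mu,
                                   float weight_decay, int64_t n) {
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    float g = __half2float(grad[i]);     // widen once
    float p = master_param[i];
    float v = mu * velocity[i] + local_lr * (g + weight_decay * p);
    velocity[i] = v;
    float p_new = p - v;
    master_param[i] = p_new;             // fp32 result kept as master copy
    param_out[i] = __float2half(p_new);  // fp16 copy handed back to the model
  }
}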