Unverified commit 0c31579c authored by limingshu, committed by GitHub

Merge lars op (#35476)

* A first attempt at cudaLaunchCooperativeKernel

* fix bugs

* Totally replace the lars cuda kernel

* Fix bugs

* a test for lars merge

* Adding lars_op_momentum infer_shape

* Fix codes

* use avg_numel instead of max_numel to acquire grid num

* modify unittest files about lars op

* Finally converge when merged-lars works

* fix ctest files

* add merged_operation kernel when cuda version is older than 11

* Fix code style

* fix ctest failure

* fix error

* fix all ctest error and change lars compute code of cpu

* fix bugs on v100.

* revert python modification about lars

* revert python modification codes
Parent 24418479
......@@ -13,46 +13,158 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
#include "paddle/fluid/operators/optimizers/momentum_op.h"
namespace paddle {
namespace operators {
class LarsMomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("Param"), "Input", "Param", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("Grad"), "Input", "Grad", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("Velocity"), "Input", "Velocity",
"LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("LearningRate"), "Input", "LearningRate",
"LarsMomentum");
OP_INOUT_CHECK(ctx->HasOutputs("ParamOut"), "Output", "ParamOut",
"LarsMomentum");
OP_INOUT_CHECK(ctx->HasOutputs("VelocityOut"), "Output", "VelocityOut",
"LarsMomentum");
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be LoDTensor, but the received is %s",
ctx->GetInputsVarType("Param").front()));
auto lr_dims = ctx->GetInputsDim("LearningRate");
auto grad_dim = ctx->GetInputsDim("Grad");
auto param_dim = ctx->GetInputsDim("Param");
auto velocity_dim = ctx->GetInputsDim("Velocity");
auto lars_weight_decays =
ctx->Attrs().Get<std::vector<float>>("lars_weight_decay");
auto multi_precision = ctx->Attrs().Get<bool>("multi_precision");
PADDLE_ENFORCE_EQ(
param_dim.size(), grad_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) of LarsMomentumOp should have "
"same quantity. But number of Param is [%d] and Grad is [%d].",
param_dim.size(), grad_dim.size()));
PADDLE_ENFORCE_EQ(
param_dim.size(), velocity_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp should "
"have same quantity. But number of Param is [%d] and Velocity "
"is [%d].",
param_dim.size(), velocity_dim.size()));
PADDLE_ENFORCE_EQ(
lars_weight_decays.size(), grad_dim.size(),
platform::errors::InvalidArgument(
"Attr(Lars_weight_decay) and "
"Input(Grad) of LarsMomentumOp should have same quantity. "
"But number of Lars_weight_decay is [%d] and Grad is [%d].",
lars_weight_decays.size(), grad_dim.size()));
if (multi_precision) {
OP_INOUT_CHECK(ctx->HasInputs("MasterParam"), "Input", "MasterParam",
"LarsMomentumMultiPrecision");
OP_INOUT_CHECK(ctx->HasOutputs("MasterParamOut"), "Output",
"MasterParamOut", "LarsMomentumMultiPrecision");
}
for (size_t i = 0; i < lr_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(framework::product(lr_dims[i]), 1,
platform::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
framework::product(lr_dims[i])));
}
for (size_t i = 0; i < param_dim.size(); ++i) {
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Grad")[i],
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx->Inputs("Grad")[i].front(),
ctx->GetInputsVarType("Grad")[i]));
PADDLE_ENFORCE_EQ(
param_dim[i], grad_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) input of LarsMomentumOp shall "
"have same dimension. But Param`s dim is [%s] and Grad's dim "
"is [%s].",
param_dim[i], grad_dim[i]));
PADDLE_ENFORCE_EQ(
param_dim[i], velocity_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp shall have "
"same dimension. But Param dim [%s] differs with Velocity dim "
"[%s].",
param_dim[i], velocity_dim[i]));
}
ctx->SetOutputsDim("ParamOut", param_dim);
ctx->SetOutputsDim("VelocityOut", param_dim);
if (ctx->HasOutputs("MasterParamOut")) {
ctx->SetOutputsDim("MasterParamOut", param_dim);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param",
"(LoDTensor, default LoDTensor<float>) "
"Input parameter that has to be updated");
"Input parameter that has to be updated")
.AsDuplicable();
AddInput("Grad",
"(LoDTensor, default LoDTensor<float>) "
"Input gradient of the parameter");
"Input gradient of the parameter")
.AsDuplicable();
AddInput("Velocity",
"(LoDTensor, default LoDTensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated");
"that has to be updated")
.AsDuplicable();
AddInput("LearningRate",
"(LoDTensor, default LoDTensor<float>) "
"Input learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
"Input learning rate")
.AsDuplicable();
AddInput("MasterParam", "FP32 master weight for AMP.")
.AsDuplicable()
.AsDispensable();
AddOutput("ParamOut",
"(LoDTensor) This output is updated parameter. "
"It shared memory with Input(Param).");
"It shared memory with Input(Param).")
.AsDuplicable();
AddOutput("VelocityOut",
"(LoDTensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
"It shared memory with Input(Velocity).")
.AsDuplicable();
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDuplicable()
.AsDispensable();
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
.SetDefault(0.001);
AddAttr<float>("lars_weight_decay",
"(float, default 0.0005) LARS weight decay")
.SetDefault(0.0005);
AddAttr<std::vector<float>>(
"lars_weight_decay",
"(std::vector<float>, default 0.0005) LARS weight decay params")
.SetDefault({0.0005});
AddAttr<float>("epsilon",
"(float, default 0.0) epsilon to avoid Division by Zero.")
.SetDefault(0.0);
......@@ -96,7 +208,7 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
namespace ops = paddle::operators;
REGISTER_OPERATOR(
lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker,
lars_momentum, ops::LarsMomentumOp, ops::LarsMomentumOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::LarsMomentumOpVarTypeInference);
......
......@@ -18,18 +18,8 @@ limitations under the License. */
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
#include "paddle/fluid/platform/fast_divmod.h"
#if defined(__NVCC__) && CUDA_VERSION >= 11000
/* Once CUDA_VERSION is beyond 11.0, cooperative_groups can be used without
adding the --rdc=true compile flag, so the L2_norm cuda kernel can be
declared as a __device__ function rather than a __global__ kernel. On the
contrary, the compile flag must be set for older versions, which may affect
cuda kernel performance in paddle; consequently, the L2_norm kernel shall be
declared as a __global__ kernel.
*/
#if CUDA_VERSION >= 11000
#include <cooperative_groups.h>
#define LARS_FUNCTION_FLAG __device__
#else
#define LARS_FUNCTION_FLAG __global__
#endif
#ifdef __HIPCC__
......@@ -38,6 +28,8 @@ limitations under the License. */
#define LARS_BLOCK_SIZE 512
#endif
#define LARS_MAX_MERGED_OPS 150
namespace paddle {
namespace operators {
......@@ -53,6 +45,43 @@ __device__ __forceinline__ double Fma(double x, double y, double z) {
return fma(x, y, z);
}
template <typename T>
class LarsThreadConfig {
public:
int grid_for_norm;
int grid_for_lars;
#if CUDA_VERSION >= 11000
private:
int grid_stride;
public:
explicit LarsThreadConfig(int64_t numel, int sm_num, int num_blocks_per_sm) {
int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
grid_for_lars =
std::min(std::min(sm_num * num_blocks_per_sm, grid), LARS_BLOCK_SIZE);
grid_stride = LARS_BLOCK_SIZE * grid_for_lars;
}
int GetRepeatTimes(int64_t numel) {
return (numel + grid_stride - 1) / grid_stride - 1;
}
#else
int repeat_times;
explicit LarsThreadConfig(const int64_t numel) {
int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
grid_for_norm = std::min(grid, LARS_BLOCK_SIZE);
const int grid_stride = grid_for_norm * LARS_BLOCK_SIZE;
repeat_times = (numel + grid_stride - 1) / grid_stride - 1;
// Read 4 fp16 or float values at a time, but only 2 double values at a time.
grid_for_lars =
std::is_same<double, T>::value
? (numel + (LARS_BLOCK_SIZE << 1) - 1) / (LARS_BLOCK_SIZE << 1)
: (numel + (LARS_BLOCK_SIZE << 2) - 1) / (LARS_BLOCK_SIZE << 2);
}
#endif
};
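For intuition, here is a small host-only sketch (not part of the commit) that reproduces the CUDA >= 11 sizing above for one hypothetical tensor; the SM count and blocks-per-SM values are assumed, and LARS_BLOCK_SIZE is taken as 512 as defined above.

```cuda
// Standalone check of the CUDA >= 11 branch of LarsThreadConfig.
// sm_num and num_blocks_per_sm are illustrative device limits, not values
// taken from the commit.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int kBlockSize = 512;                    // LARS_BLOCK_SIZE
  const int64_t numel = 1000000;                 // one hypothetical tensor
  const int sm_num = 80, num_blocks_per_sm = 4;  // assumed device limits

  int grid = static_cast<int>((numel + kBlockSize - 1) / kBlockSize);    // 1954
  int grid_for_lars =
      std::min(std::min(sm_num * num_blocks_per_sm, grid), kBlockSize);  // 320
  int grid_stride = kBlockSize * grid_for_lars;                          // 163840
  int repeat_times =
      static_cast<int>((numel + grid_stride - 1) / grid_stride - 1);     // 6

  std::printf("grid_for_lars=%d repeat_times=%d\n", grid_for_lars, repeat_times);
  return 0;
}
```

Each of the 320 blocks then walks its slice of the tensor six full times in the grid-stride loop, and the kernels below cover the remaining tail with a bounds check.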
template <typename T, typename MT, int VecSize, bool IsAmp = false>
__device__ inline void VectorizeLarsUpdate(
const T* __restrict__ grad, const MT* __restrict__ param,
......@@ -85,7 +114,6 @@ __device__ inline void VectorizeLarsUpdate(
VecType grad_data = grad_vec[i];
VecMType param_data = param_vec[i];
VecMType velocity_data = velocity_vec[i];
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
MT grad_val = static_cast<MT>(grad_data[j]) * rescale_grad;
......@@ -116,41 +144,49 @@ __device__ inline void VectorizeLarsUpdate(
}
}
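The elided body above follows the usual vectorized grid-stride pattern: reinterpret the pointers as VecSize-wide vectors (the VecType/VecMType aliases visible above), update the aligned body, then fall back to scalars for the tail. A minimal sketch of that pattern, with a plain AXPY standing in for the LARS math and an assumed kernel name, is:

```cuda
// Illustrative only: vectorized grid-stride update with a scalar tail.
// Assumes x and y are 16-byte aligned, as float4 loads require.
#include <cstdint>

__global__ void VectorizedAxpy(const float* __restrict__ x,
                               float* __restrict__ y, float a, int64_t n) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  const int64_t main_n = n >> 2;  // number of full float4 groups

  const float4* x4 = reinterpret_cast<const float4*>(x);
  float4* y4 = reinterpret_cast<float4*>(y);
  for (int64_t i = tid; i < main_n; i += stride) {
    float4 xv = x4[i];
    float4 yv = y4[i];
    yv.x += a * xv.x;
    yv.y += a * xv.y;
    yv.z += a * xv.z;
    yv.w += a * xv.w;
    y4[i] = yv;
  }
  // Scalar tail for the last n % 4 elements.
  for (int64_t i = (main_n << 2) + tid; i < n; i += stride) {
    y[i] += a * x[i];
  }
}
```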
#if CUDA_VERSION >= 11000
/* Once CUDA_VERSION is beyond 11, cooperative_groups can be used without the
--rdc=true compile flag; the L2_norm kernel can then be declared __device__
and cooperative_groups::grid_group can also be used. Otherwise, adding this
flag costs too much, so the L2_norm kernel shall be declared __global__.*/
// TODO(limingshu): declaration of cooperative_groups wrapper is invalid in host.
template <typename T, typename MT>
__forceinline__ __device__ void L2NormKernel(
const cooperative_groups::grid_group* cg,
#else
template <typename T, typename MT>
LARS_FUNCTION_FLAG void L2NormKernel(
__global__ void L2NormKernel(
#endif
const T* __restrict__ p_data, const T* __restrict__ g_data,
MT* __restrict__ p_buffer, MT* __restrict__ g_buffer,
const int repeat_times, const int64_t numel, const MT rescale_grad,
MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const int64_t numel,
const int repeat_times, const MT rescale_grad, const int thresh = 0,
MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) {
__shared__ MT s_buffer[2];
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
const MT rescale_grad_pow = rescale_grad * rescale_grad;
__shared__ MT s_buffer[2];
const MT rescale_pow = rescale_grad * rescale_grad;
s_buffer[0] = static_cast<MT>(0);
s_buffer[1] = static_cast<MT>(0);
MT p_tmp_val = static_cast<MT>(0);
MT g_tmp_val = static_cast<MT>(0);
MT p_tmp = static_cast<MT>(0);
MT g_tmp = static_cast<MT>(0);
if (repeat_times == 0) {
if (tid < numel) {
p_tmp_val = static_cast<MT>(p_data[tid]);
g_tmp_val = static_cast<MT>(g_data[tid]);
p_tmp = static_cast<MT>(p_data[tid]);
g_tmp = static_cast<MT>(g_data[tid]);
}
s_buffer[0] += math::blockReduceSum<MT>(p_tmp_val * p_tmp_val, FINAL_MASK);
s_buffer[1] += math::blockReduceSum<MT>(g_tmp_val * g_tmp_val, FINAL_MASK);
s_buffer[0] += math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
s_buffer[1] += math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
} else {
/* To avoid occupying too much temp buffer, slice the whole data into 2
parts: the front part, whose quantity is exactly a multiple of the
grid-thread number, is dealt with in the for loop; the rest is dealt with
in another step to avoid visiting addresses beyond bound. */
/* Avoid occupying too much temp buffer. Slice the whole data into 2 parts:
the front part, whose quantity is exactly a multiple of the grid-thread
number, is dealt with in the for loop; the rest is dealt with in another step. */
for (int i = 0; i < repeat_times; ++i) {
p_tmp_val = static_cast<MT>(p_data[tid]);
g_tmp_val = static_cast<MT>(g_data[tid]);
p_tmp = static_cast<MT>(p_data[tid]);
g_tmp = static_cast<MT>(g_data[tid]);
tid += grid_stride;
s_buffer[0] +=
math::blockReduceSum<MT>(p_tmp_val * p_tmp_val, FINAL_MASK);
s_buffer[1] +=
math::blockReduceSum<MT>(g_tmp_val * g_tmp_val, FINAL_MASK);
s_buffer[0] += math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
s_buffer[1] += math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
__syncthreads();
}
MT p_val = 0;
......@@ -168,69 +204,46 @@ LARS_FUNCTION_FLAG void L2NormKernel(
p_buffer[blockIdx.x] = s_buffer[0];
g_buffer[blockIdx.x] = s_buffer[1];
}
#if CUDA_VERSION >= 11000
// Grid sync to make sure all partial results are written back to global memory
const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
cg.sync();
MT p_partial_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0;
MT g_partial_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0;
*p_n = Sqrt(math::blockReduceSum<MT>(p_partial_sum, FINAL_MASK));
*g_n = Sqrt(rescale_grad_pow *
math::blockReduceSum<MT>(g_partial_sum, FINAL_MASK));
cg->sync(); // Grid sync for writring partial result to gloabl memory
MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0;
MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0;
*p_n = Sqrt(math::blockReduceSum<MT>(p_part_sum, FINAL_MASK));
*g_n = Sqrt(rescale_pow * math::blockReduceSum<MT>(g_part_sum, FINAL_MASK));
#endif
}
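To make the grid-synchronized path easier to follow, here is a compact, self-contained sketch of the same two-stage reduction idea: every block writes a partial sum, grid.sync() guarantees all partials are visible, and the partials are reduced again. It assumes a cooperative launch with blockDim.x == 256 and gridDim.x <= 256 (the real kernel gets the equivalent guarantee from LarsThreadConfig); names are illustrative, and the real kernel additionally folds in rescale_grad and reduces the parameter and gradient norms together through s_buffer.

```cuda
#include <cooperative_groups.h>
#include <cstdint>
namespace cg = cooperative_groups;

// Must be launched via cudaLaunchCooperativeKernel, otherwise grid.sync()
// and the co-residency assumption below do not hold.
__global__ void GridL2Norm(const float* __restrict__ x,
                           float* __restrict__ partial,
                           float* __restrict__ out, int64_t n) {
  __shared__ float smem[256];
  float val = 0.f;
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    val += x[i] * x[i];
  }
  // Stage 1: per-block tree reduction into partial[blockIdx.x].
  smem[threadIdx.x] = val;
  __syncthreads();
  for (int s = blockDim.x >> 1; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) partial[blockIdx.x] = smem[0];

  // Stage 2: wait for every block, then reduce the per-block partials.
  cg::this_grid().sync();
  smem[threadIdx.x] = threadIdx.x < gridDim.x ? partial[threadIdx.x] : 0.f;
  __syncthreads();
  for (int s = blockDim.x >> 1; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }
  if (blockIdx.x == 0 && threadIdx.x == 0) *out = sqrtf(smem[0]);
}
```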
template <typename T, typename MT>
__global__ void MomentumLarsKernel(
__forceinline__ __device__ void MomentumUpdate(
const T* __restrict__ param, const T* __restrict__ grad,
const MT* __restrict__ velocity, T* param_out, MT* velocity_out,
const MT* __restrict__ master_param, MT* __restrict__ master_param_out,
const MT* __restrict__ learning_rate, MT* __restrict__ p_buffer,
MT* __restrict__ g_buffer, const MT mu, const MT lars_coeff,
const MT lars_weight_decay, const MT epsilon, const MT rescale_grad,
const int repeat_times, const int thresh, const int64_t numel) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
#if CUDA_VERSION >= 11000
MT param_norm = static_cast<MT>(0);
MT grad_norm = static_cast<MT>(0);
L2NormKernel<T, MT>(param, grad, p_buffer, g_buffer, repeat_times, numel,
rescale_grad, &param_norm, &grad_norm);
#else
const MT rescale_grad_pow = rescale_grad * rescale_grad;
MT param_parital_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0;
MT grad_parital_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
__syncthreads();
MT param_norm =
Sqrt(math::blockReduceSum<MT>(param_parital_norm, FINAL_MASK));
MT grad_norm = Sqrt(rescale_grad_pow *
math::blockReduceSum<MT>(grad_parital_norm, FINAL_MASK));
#endif
const MT* __restrict__ learning_rate, const MT mu,
const MT lars_weight_decay, const MT lars_coeff, const MT epsilon,
const MT rescale_grad, const MT param_norm, const MT grad_norm,
const int tid, const int grid_stride, const int64_t numel,
const bool is_amp) {
const MT lr = learning_rate[0];
MT local_lr = lr;
if (lars_weight_decay > static_cast<MT>(0)) {
local_lr = lr * lars_coeff * param_norm /
(Fma(lars_weight_decay, param_norm, grad_norm) + epsilon);
(fma(lars_weight_decay, param_norm, grad_norm) + epsilon);
}
if (master_param_out) {
VectorizeLarsUpdate<T, MT, 4, true>(grad, master_param, velocity, param_out,
velocity_out, mu, local_lr,
lars_weight_decay, rescale_grad, tid,
grid_stride, numel, master_param_out);
if (is_amp) {
VectorizeLarsUpdate<T, MT, /*VecSize=*/4, /*IsAmp=*/true>(
grad, master_param, velocity, param_out, velocity_out, mu, local_lr,
lars_weight_decay, rescale_grad, tid, grid_stride, numel,
master_param_out);
} else {
if (std::is_same<T, float>::value ||
std::is_same<T, paddle::platform::float16>::value) {
// For multi-precision, types T and MT are at most fp16 or fp32, so the
// maximum data IO size can be set to 4.
VectorizeLarsUpdate<T, MT, 4, false>(
/* TODO(limingshu): pointer cast may damage memory accessing for fp16 */
VectorizeLarsUpdate<T, MT, /*VecSize=*/4, /*IsAmp=*/false>(
grad, reinterpret_cast<const MT*>(param), velocity, param_out,
velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid,
grid_stride, numel);
} else {
VectorizeLarsUpdate<T, MT, 2, false>(
VectorizeLarsUpdate<T, MT, /*VecSize=*/2, /*IsAmp=*/false>(
grad, reinterpret_cast<const MT*>(param), velocity, param_out,
velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid,
grid_stride, numel);
......@@ -238,94 +251,235 @@ __global__ void MomentumLarsKernel(
}
}
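In equation form, the update that MomentumUpdate implements (and that the CPU kernel further below mirrors) is, writing p for the parameter, g for the rescaled gradient, v for the velocity, \eta for the learning rate, \mu for the momentum and \lambda for lars_weight_decay:

\[
\begin{aligned}
lr_{local} &= \eta \cdot \text{lars\_coeff} \cdot
    \frac{\lVert p \rVert_2}{\lambda \lVert p \rVert_2 + \lVert g \rVert_2 + \epsilon},\\
v_{t+1} &= \mu\, v_t + lr_{local}\,\bigl(g + \lambda\, p\bigr),\\
p_{t+1} &= p_t - v_{t+1},
\end{aligned}
\]

with lr_{local} falling back to \eta when \lambda = 0; the two norms come from L2NormKernel above, where the gradient norm already carries the rescale_grad factor.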
#if CUDA_VERSION >= 11000
template <typename T, typename MT>
struct LarsParamWarpper {
int64_t numel_arr[LARS_MAX_MERGED_OPS];
int repeat_arr[LARS_MAX_MERGED_OPS];
const T* __restrict__ p_arr[LARS_MAX_MERGED_OPS];
const T* __restrict__ g_arr[LARS_MAX_MERGED_OPS];
const MT* __restrict__ v_arr[LARS_MAX_MERGED_OPS];
const MT* __restrict__ lr_arr[LARS_MAX_MERGED_OPS];
const MT* __restrict__ master_p_arr[LARS_MAX_MERGED_OPS];
T* __restrict__ p_out_arr[LARS_MAX_MERGED_OPS];
MT* __restrict__ v_out_arr[LARS_MAX_MERGED_OPS];
MT* __restrict__ master_p_out_arr[LARS_MAX_MERGED_OPS];
MT weight_decay_arr[LARS_MAX_MERGED_OPS];
};
template <typename T, typename MT>
__global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT>* lars_warpper,
MT* __restrict__ p_buffer,
MT* __restrict__ g_buffer,
const int op_num, const MT mu,
const MT lars_coeff, const MT epsilon,
const MT rescale_grad,
const bool is_amp) {
int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
int tid = threadIdx.x + blockIdx.x * blockDim.x;
const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
for (int i = 0; i < op_num; ++i) {
int numel = lars_warpper->numel_arr[i];
MT param_norm = static_cast<MT>(0);
MT grad_norm = static_cast<MT>(0);
L2NormKernel<T, MT>(&cg, lars_warpper->p_arr[i], lars_warpper->g_arr[i],
p_buffer, g_buffer, numel, lars_warpper->repeat_arr[i],
rescale_grad, 0, &param_norm, &grad_norm);
MomentumUpdate<T, MT>(
lars_warpper->p_arr[i], lars_warpper->g_arr[i],
lars_warpper->v_out_arr[i], lars_warpper->p_out_arr[i],
lars_warpper->v_out_arr[i], lars_warpper->master_p_arr[i],
lars_warpper->master_p_out_arr[i], lars_warpper->lr_arr[i], mu,
lars_warpper->weight_decay_arr[i], lars_coeff, epsilon, rescale_grad,
param_norm, grad_norm, tid, grid_stride, numel, is_amp);
}
}
#endif
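LarsParamWarpper is a plain-old-data bundle of per-op pointers and sizes: the host fills it once, copies it to device memory, and passes the device copy to the merged kernel as a single argument (see the memory::Alloc / memory::Copy calls further down). A generic, hedged sketch of that pattern, with toy names, a trivial per-tensor scale, and a normal rather than cooperative launch, is:

```cuda
#include <cstdint>
#include <cuda_runtime.h>

constexpr int kMaxOps = 4;  // stand-in for LARS_MAX_MERGED_OPS

// Plain-old-data bundle of per-tensor device pointers and sizes.
struct ParamPack {
  int64_t numel[kMaxOps];
  float* data[kMaxOps];
};

// Toy merged kernel: a single launch walks every packed tensor.
__global__ void ScaleAll(const ParamPack* pack, int op_num, float factor) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  for (int op = 0; op < op_num; ++op) {
    for (int64_t i = tid; i < pack->numel[op]; i += stride) {
      pack->data[op][i] *= factor;
    }
  }
}

// Host side: fill the struct with device pointers, copy it to the GPU once,
// then hand the device copy to the kernel as one argument.
void LaunchScaleAll(const ParamPack& host_pack, int op_num, float factor,
                    cudaStream_t stream) {
  ParamPack* dev_pack = nullptr;
  cudaMalloc(&dev_pack, sizeof(ParamPack));
  cudaMemcpyAsync(dev_pack, &host_pack, sizeof(ParamPack),
                  cudaMemcpyHostToDevice, stream);
  ScaleAll<<<120, 256, 0, stream>>>(dev_pack, op_num, factor);
  cudaStreamSynchronize(stream);
  cudaFree(dev_pack);
}
```

Packing the pointer tables into one struct is what lets a single cooperative launch update every parameter of the merged op, at the cost of the LARS_MAX_MERGED_OPS compile-time cap enforced above.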
template <typename T, typename MT>
__global__ void MomentumLarsKernel(
const T* __restrict__ param, const T* __restrict__ grad,
const MT* __restrict__ velocity, T* param_out, MT* velocity_out,
const MT* __restrict__ master_param, MT* __restrict__ master_param_out,
const MT* __restrict__ learning_rate, MT* __restrict__ p_buffer,
MT* __restrict__ g_buffer, const MT mu, const MT lars_coeff,
const MT lars_weight_decay, const MT epsilon, const MT rescale_grad,
const int repeat_times, const int thresh, const int64_t numel,
const bool is_amp) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
#if CUDA_VERSION >= 11000
const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
MT param_norm = static_cast<MT>(0);
MT grad_norm = static_cast<MT>(0);
L2NormKernel<T, MT>(&cg, param, grad, p_buffer, g_buffer, numel, repeat_times,
rescale_grad, gridDim.x, &param_norm, &grad_norm);
#else
const MT rescale_grad_pow = rescale_grad * rescale_grad;
MT param_part_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0;
MT grad_part_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
__syncthreads();
MT param_norm = Sqrt(math::blockReduceSum<MT>(param_part_norm, FINAL_MASK));
MT grad_norm = Sqrt(rescale_grad_pow *
math::blockReduceSum<MT>(grad_part_norm, FINAL_MASK));
#endif
MomentumUpdate<T, MT>(param, grad, velocity, param_out, velocity_out,
master_param, master_param_out, learning_rate, mu,
lars_weight_decay, lars_coeff, epsilon, rescale_grad,
param_norm, grad_norm, tid, grid_stride, numel, is_amp);
}
template <typename T, typename MT>
inline void SeparatedLarsMomentumOpCUDAKernel(
const platform::CUDADeviceContext& cuda_ctx, const T* param_data,
T* param_out_data, const MT* velocity_data, MT* velocity_out_data,
const T* grad_data, const MT* lr, MT* p_buffer, MT* g_buffer, const MT mu,
const MT lars_coeff, const MT weight_decay, const MT epsilon,
const MT rescale_grad, const int64_t numel, const MT* master_param_data,
MT* master_out_data, const bool is_amp) {
LarsThreadConfig<T> lars_thread_config(numel);
L2NormKernel<T, MT><<<lars_thread_config.grid_for_norm, LARS_BLOCK_SIZE, 0,
cuda_ctx.stream()>>>(
param_data, grad_data, p_buffer, g_buffer, numel,
lars_thread_config.repeat_times, rescale_grad);
MomentumLarsKernel<T, MT><<<lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE,
0, cuda_ctx.stream()>>>(
param_data, grad_data, velocity_data, param_out_data, velocity_out_data,
master_param_data, master_out_data, lr, p_buffer, g_buffer, mu,
lars_coeff, weight_decay, epsilon, rescale_grad, 0,
lars_thread_config.grid_for_norm, numel, is_amp);
}
template <typename DeviceContext, typename T>
class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
using MT = MultiPrecisionType<T>;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const bool multi_precision = ctx.Attr<bool>("multi_precision");
auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
auto param = ctx.Input<framework::LoDTensor>("Param");
auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
auto grad = ctx.Input<framework::LoDTensor>("Grad");
auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
int64_t numel = param->numel();
int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
const framework::Tensor* master_param = nullptr;
framework::Tensor* master_param_out = nullptr;
const MT* master_param_data = nullptr;
MT* master_param_out_data = nullptr;
int num_blocks_per_sm = 0;
bool multi_precision = ctx.Attr<bool>("multi_precision");
auto& cuda_ctx = ctx.template device_context<platform::CUDADeviceContext>();
int sm_num = cuda_ctx.GetSMCount();
framework::Tensor tmp_buffer_t =
ctx.AllocateTmpTensor<MT, platform::CUDADeviceContext>(
{LARS_BLOCK_SIZE << 1}, cuda_ctx);
auto* p_buffer = tmp_buffer_t.mutable_data<MT>(ctx.GetPlace());
auto* g_buffer = p_buffer + LARS_BLOCK_SIZE;
if (multi_precision) {
bool has_master =
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
PADDLE_ENFORCE_EQ(has_master, true,
platform::errors::InvalidArgument(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"));
master_param = ctx.Input<framework::Tensor>("MasterParam");
master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
master_param_data = master_param->data<MT>();
master_param_out_data =
master_param_out->mutable_data<MT>(ctx.GetPlace());
}
MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
MT lars_coeff = static_cast<MT>(ctx.Attr<float>("lars_coeff"));
MT lars_weight_decay =
static_cast<MT>(ctx.Attr<float>("lars_weight_decay"));
MT epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad"));
auto* param_data = param->data<T>();
auto* grad_data = grad->data<T>();
auto* velocity_data = velocity->data<MT>();
auto* lr = learning_rate->data<MT>();
auto& cuda_ctx = ctx.template device_context<platform::CUDADeviceContext>();
T* param_out_data = param_out->mutable_data<T>(ctx.GetPlace());
MT* velocity_out_data = velocity_out->mutable_data<MT>(ctx.GetPlace());
auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
auto param = ctx.MultiInput<framework::LoDTensor>("Param");
auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity");
auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
auto learning_rate = ctx.MultiInput<framework::LoDTensor>("LearningRate");
auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam");
auto master_param_out =
ctx.MultiOutput<framework::LoDTensor>("MasterParamOut");
int op_num = grad.size();
#if CUDA_VERSION >= 11000
/*
When a model is trained with the lars optimizer, its principal
implementation is achieved by the following two steps:
if (op_num > 1) {
LarsParamWarpper<T, MT> lars_warpper;
PADDLE_ENFORCE_LT(
op_num, LARS_MAX_MERGED_OPS,
platform::errors::InvalidArgument(
"The maximum number of merged-ops supported is (%d), but"
"lars op required for trainning this model is (%d)\n",
LARS_MAX_MERGED_OPS, op_num));
/* Implementation of lars optimizer consists of following two steps:
1. Figure out the L2 norm statistic result of grad data and param data.
2. Update param and velocity data with usage of L2 norm statistic result.
Originally, these two steps were fulfilled by a separate eigen function and
cuda kernel; however, the overhead of the eigen function took a large share
of the total time and hurt the performance of the lars op, which made it
necessary to combine the 2 steps into one cuda kernel.
Since step 1 is an l2-norm statistic, a grid-level reduce is needed. To
achieve this and continue with step 2 in only one global launch, the
essential basis is to control all grid threads while running. Apart from
the normal launch form, cuda 9.0 provides the `cudaLaunchCooperativeKernel`
api:
2. Update param and velocity with usage of L2 norm statistic result.
Step 1 and step 2 can be merged with an api provided by nvidia,
cudaLaunchCooperativeKernel:
- The thread quantity shall be less than the physical SM thread limit
- Launches a device function where thread blocks can cooperate and
synchronize as they execute.
*/
- Launches thread blocks that can synchronize as they execute. */
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm, MergedMomentumLarsKernel<T, MT>, LARS_BLOCK_SIZE,
sizeof(MT) << 1);
size_t total_numel = 0;
for (int i = 0; i < op_num; ++i) {
size_t temp_numel = param[i]->numel();
total_numel += temp_numel;
lars_warpper.numel_arr[i] = temp_numel;
lars_warpper.p_arr[i] = param[i]->data<T>();
lars_warpper.g_arr[i] = grad[i]->data<T>();
lars_warpper.v_arr[i] = velocity[i]->data<MT>();
lars_warpper.lr_arr[i] = learning_rate[i]->data<MT>();
lars_warpper.p_out_arr[i] =
param_out[i]->mutable_data<T>(ctx.GetPlace());
lars_warpper.v_out_arr[i] =
velocity_out[i]->mutable_data<MT>(ctx.GetPlace());
lars_warpper.weight_decay_arr[i] = static_cast<MT>(weight_decay_arr[i]);
}
int64_t avg_numel = total_numel / op_num;
LarsThreadConfig<float> lars_thread_config(avg_numel, sm_num,
num_blocks_per_sm);
for (int i = 0; i < op_num; ++i) {
lars_warpper.repeat_arr[i] =
lars_thread_config.GetRepeatTimes(lars_warpper.numel_arr[i]);
}
if (multi_precision) {
for (int i = 0; i < op_num; ++i) {
lars_warpper.master_p_arr[i] = master_param[i]->data<MT>();
lars_warpper.master_p_out_arr[i] =
master_param_out[i]->mutable_data<MT>(ctx.GetPlace());
}
}
auto merged_buf = memory::Alloc(cuda_ctx, sizeof(lars_warpper));
auto* merged_ptr =
reinterpret_cast<LarsParamWarpper<T, MT>*>(merged_buf->ptr());
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, cuda_ctx.GetPlace()),
reinterpret_cast<void*>(merged_ptr), platform::CPUPlace(),
reinterpret_cast<void*>(&lars_warpper), sizeof(lars_warpper),
cuda_ctx.stream());
void* cuda_param[] = {reinterpret_cast<void*>(&merged_ptr),
reinterpret_cast<void*>(&p_buffer),
reinterpret_cast<void*>(&g_buffer),
reinterpret_cast<void*>(&op_num),
reinterpret_cast<void*>(&mu),
reinterpret_cast<void*>(&lars_coeff),
reinterpret_cast<void*>(&epsilon),
reinterpret_cast<void*>(&rescale_grad),
reinterpret_cast<void*>(&multi_precision)};
// Launch all SM threads; the threads of each block cooperate synchronously.
cudaLaunchCooperativeKernel(
reinterpret_cast<void*>(MergedMomentumLarsKernel<T, MT>),
lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0,
cuda_ctx.stream());
} else {
auto* param_data = param[0]->data<T>();
auto* grad_data = grad[0]->data<T>();
auto* velocity_data = velocity[0]->data<MT>();
auto* lr = learning_rate[0]->data<MT>();
auto* param_out_data = param_out[0]->mutable_data<T>(ctx.GetPlace());
auto* velocity_out_data =
velocity_out[0]->mutable_data<MT>(ctx.GetPlace());
const MT* master_param_data =
multi_precision ? master_param[0]->data<MT>() : nullptr;
MT* master_param_out_data =
multi_precision
? master_param_out[0]->mutable_data<MT>(ctx.GetPlace())
: nullptr;
int64_t numel = param[0]->numel();
MT lars_weight_decay = weight_decay_arr[0];
// Figure out how many blocks can be active in each sm.
int num_blocks_per_sm = 0;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm,
MomentumLarsKernel<T, MT>,
LARS_BLOCK_SIZE, sizeof(MT));
int sm_num = cuda_ctx.GetSMCount();
int grid_real =
std::min(std::min(sm_num * num_blocks_per_sm, grid), LARS_BLOCK_SIZE);
framework::Tensor tmp_buffer_t =
ctx.AllocateTmpTensor<MT, platform::CUDADeviceContext>(
{LARS_BLOCK_SIZE << 1}, cuda_ctx);
auto* p_buffer = tmp_buffer_t.mutable_data<MT>(ctx.GetPlace());
auto* g_buffer = p_buffer + LARS_BLOCK_SIZE;
int grid_stride = LARS_BLOCK_SIZE * grid;
int repeat_times = (numel + grid_stride - 1) / grid_stride - 1;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm, MomentumLarsKernel<T, MT>, LARS_BLOCK_SIZE,
sizeof(MT) << 1);
LarsThreadConfig<float> lars_thread_config(numel, sm_num,
num_blocks_per_sm);
int repeat_times = lars_thread_config.GetRepeatTimes(numel);
int thresh = 0;
// Uniform kernel parameter for cudaLaunchCooperativeKernel
void* cuda_param[] = {
reinterpret_cast<void*>(&param_data),
reinterpret_cast<void*>(&grad_data),
......@@ -344,38 +498,31 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
reinterpret_cast<void*>(&rescale_grad),
reinterpret_cast<void*>(&repeat_times),
reinterpret_cast<void*>(&thresh), // Just a placeholder
reinterpret_cast<void*>(&numel)};
reinterpret_cast<void*>(&numel),
reinterpret_cast<void*>(&multi_precision)};
// Launch all SM threads.
cudaLaunchCooperativeKernel(
reinterpret_cast<void*>(MomentumLarsKernel<T, MT>), grid_real,
LARS_BLOCK_SIZE, cuda_param, 0, cuda_ctx.stream());
reinterpret_cast<void*>(MomentumLarsKernel<T, MT>),
lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0,
cuda_ctx.stream());
}
#else
// Read 4 fp16 or float values at a time, but only 2 double values at a time.
int grid_lars =
sizeof(T) < sizeof(double)
? (numel + (LARS_BLOCK_SIZE << 2) - 1) / (LARS_BLOCK_SIZE << 2)
: (numel + (LARS_BLOCK_SIZE << 1) - 1) / (LARS_BLOCK_SIZE << 1);
int grid_norm = std::min(grid, LARS_BLOCK_SIZE);
framework::Tensor p_buffer_t =
ctx.AllocateTmpTensor<MT, platform::CUDADeviceContext>(
{LARS_BLOCK_SIZE << 1}, cuda_ctx);
auto* p_buffer = p_buffer_t.mutable_data<MT>(ctx.GetPlace());
auto* g_buffer = p_buffer + LARS_BLOCK_SIZE;
const int grid_stride = LARS_BLOCK_SIZE * grid_norm;
const int repeat_times = (numel + grid_stride - 1) / grid_stride - 1;
L2NormKernel<T, MT><<<grid_norm, LARS_BLOCK_SIZE, 0, cuda_ctx.stream()>>>(
param_data, grad_data, p_buffer, g_buffer, repeat_times, numel,
rescale_grad);
MomentumLarsKernel<
T, MT><<<grid_lars, LARS_BLOCK_SIZE, 0, cuda_ctx.stream()>>>(
param_data, grad_data, velocity_data, param_out_data, velocity_out_data,
master_param_data, master_param_out_data, lr, p_buffer, g_buffer, mu,
lars_coeff, lars_weight_decay, epsilon, rescale_grad, 0, grid_norm,
numel); // 0 is just a placeholder.
for (int i = 0; i < op_num; ++i) {
const MT* master_param_data =
multi_precision ? master_param[i]->data<MT>() : nullptr;
MT* master_param_out_data =
multi_precision
? master_param_out[i]->mutable_data<MT>(ctx.GetPlace())
: nullptr;
SeparatedLarsMomentumOpCUDAKernel<T, MT>(
cuda_ctx, param[i]->data<T>(),
param_out[i]->mutable_data<T>(ctx.GetPlace()),
velocity[i]->data<MT>(),
velocity_out[i]->mutable_data<MT>(ctx.GetPlace()), grad[i]->data<T>(),
learning_rate[i]->data<MT>(), p_buffer, g_buffer, mu, lars_coeff,
weight_decay_arr[i], epsilon, rescale_grad, param[i]->numel(),
master_param_data, master_param_out_data, multi_precision);
}
#endif
}
};
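For readers unfamiliar with the launch path used in both branches above, the sketch below shows the bare cooperative-launch sequence: query occupancy, clamp the grid so every block can be resident at once (a hard requirement for grid.sync()), pack the kernel arguments as an array of pointers, and call cudaLaunchCooperativeKernel. The kernel and its argument list are placeholders, not the Paddle kernels.

```cuda
#include <algorithm>
#include <cooperative_groups.h>
#include <cstdint>
#include <cuda_runtime.h>

namespace cg = cooperative_groups;

// Placeholder cooperative kernel; its only purpose is to be launchable.
__global__ void MyCoopKernel(const float* x, float* out, int64_t n) {
  cg::this_grid().sync();  // legal only under a cooperative launch
  if (blockIdx.x == 0 && threadIdx.x == 0 && n > 0) *out = x[0];
}

void LaunchCooperatively(const float* x, float* out, int64_t n,
                         cudaStream_t stream) {
  constexpr int kBlockSize = 512;
  int dev = 0, sm_num = 0, blocks_per_sm = 0;
  cudaGetDevice(&dev);
  cudaDeviceGetAttribute(&sm_num, cudaDevAttrMultiProcessorCount, dev);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &blocks_per_sm, MyCoopKernel, kBlockSize, /*dynamicSMemSize=*/0);

  // Every block must be co-resident, so the grid is capped by occupancy.
  int grid = static_cast<int>((n + kBlockSize - 1) / kBlockSize);
  grid = std::min(grid, sm_num * blocks_per_sm);

  void* args[] = {&x, &out, &n};
  cudaLaunchCooperativeKernel(reinterpret_cast<void*>(MyCoopKernel),
                              dim3(grid), dim3(kBlockSize), args,
                              /*sharedMem=*/0, stream);
}
```

If the tensors need more blocks than can be co-resident, the grid is simply capped and the kernel makes up the difference with its grid-stride repeat_times loop, which is exactly what LarsThreadConfig encodes above.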
......
......@@ -23,36 +23,29 @@ template <typename T>
class LarsMomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
auto param = ctx.Input<framework::LoDTensor>("Param");
auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
auto* grad_var = ctx.InputVar("Grad");
// only support dense for now.
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type())));
auto grad = ctx.Input<framework::LoDTensor>("Grad");
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
auto param = ctx.MultiInput<framework::LoDTensor>("Param");
auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity");
auto learning_rate = ctx.MultiInput<framework::LoDTensor>("LearningRate");
auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
T mu = static_cast<T>(ctx.Attr<float>("mu"));
T lars_coeff = ctx.Attr<float>("lars_coeff");
T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
T epsilon = ctx.Attr<float>("epsilon");
auto p_out = framework::EigenVector<T>::Flatten(*param_out);
auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
int op_num = param.size();
for (int i = 0; i < op_num; ++i) {
auto* lr = learning_rate[i]->data<T>();
T lars_weight_decay = weight_decay_arr[i];
param_out[i]->mutable_data<T>(ctx.GetPlace());
velocity_out[i]->mutable_data<T>(ctx.GetPlace());
auto p = framework::EigenVector<T>::Flatten(*param);
auto v = framework::EigenVector<T>::Flatten(*velocity);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto* lr = learning_rate->data<T>();
auto p_out = framework::EigenVector<T>::Flatten(*(param_out[i]));
auto v_out = framework::EigenVector<T>::Flatten(*(velocity_out[i]));
auto p = framework::EigenVector<T>::Flatten(*(param[i]));
auto v = framework::EigenVector<T>::Flatten(*(velocity[i]));
auto g = framework::EigenVector<T>::Flatten(*(grad[i]));
framework::Tensor p_norm_t, g_norm_t;
p_norm_t.Resize({1});
......@@ -61,9 +54,9 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {
g_norm_t.mutable_data<T>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
ep_norm = p.square().sum().sqrt();
eg_norm = g.square().sum().sqrt();
T local_lr = lr[0];
if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
local_lr = lr[0] * lars_coeff * ep_norm(0) /
......@@ -72,6 +65,7 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {
v_out = v * mu + local_lr * (g + lars_weight_decay * p);
p_out = p - v_out;
}
}
};
} // namespace operators
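As a quick sanity check of the scalar math in the Eigen kernel above, here is a tiny host-only reference with made-up values; epsilon is not visible in the CPU snippet shown here, so it is taken from the CUDA path and set to zero, which is an assumption.

```cuda
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> p = {1.f, -2.f, 0.5f}, g = {0.1f, 0.2f, -0.3f},
                     v = {0.f, 0.f, 0.f};
  const float lr = 0.001f, mu = 0.0001f, lars_coeff = 0.001f,
              weight_decay = 0.0005f, epsilon = 0.f;  // epsilon assumed

  float p_norm = 0.f, g_norm = 0.f;
  for (size_t i = 0; i < p.size(); ++i) {
    p_norm += p[i] * p[i];
    g_norm += g[i] * g[i];
  }
  p_norm = std::sqrt(p_norm);
  g_norm = std::sqrt(g_norm);

  float local_lr = lr;
  if (weight_decay > 0.f && p_norm > 0.f && g_norm > 0.f) {
    local_lr =
        lr * lars_coeff * p_norm / (g_norm + weight_decay * p_norm + epsilon);
  }
  for (size_t i = 0; i < p.size(); ++i) {
    v[i] = mu * v[i] + local_lr * (g[i] + weight_decay * p[i]);
    p[i] -= v[i];
    std::printf("p[%zu] = %f, v[%zu] = %f\n", i, p[i], i, v[i]);
  }
  return 0;
}
```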
......
......@@ -2066,7 +2066,7 @@ class LarsMomentumOptimizer(Optimizer):
attrs = {
"mu": self._momentum,
"lars_coeff": self._lars_coeff,
"lars_weight_decay": _lars_weight_decay,
"lars_weight_decay": [_lars_weight_decay],
"multi_precision": find_master,
"rescale_grad": self._rescale_grad
}
......
......@@ -103,7 +103,7 @@ class TestFleetLarsMetaOptimizer(unittest.TestCase):
'op_role_var')[0] or ".b" in op.attr('op_role_var')[0])
]
for op in ops_without_wd:
self.assertEqual(op.attr('lars_weight_decay'), 0)
self.assertEqual(op.attr('lars_weight_decay')[0], 0)
def test_lars_apply_with_amp(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
......
......@@ -138,33 +138,27 @@ class TestMomentumOp2(OpTest):
"core is not compiled with CUDA")
class TestLarsMomentumOpWithMP(OpTest):
def setUp(self):
self.config()
self.op_type = "lars_momentum"
mu = 0.0001
lars_coeff = 0.001
lars_weight_decay = 0.0005
rescale_grad = 1.0
params = []
grads = []
velocitys = []
learning_rates = []
master_params = []
param_outs = []
velocity_outs = []
master_param_outs = []
for i in range(self.params_num):
master_param = np.random.random((123, 321)).astype("float32")
param = master_param.astype("float16")
grad = np.random.random((123, 321)).astype("float16")
velocity = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([0.001]).astype("float32")
mu = 0.0001
lars_coeff = 0.001
lars_weight_decay = 0.0005
rescale_grad = 1.0
self.inputs = {
'Param': param,
'Grad': grad,
'Velocity': velocity,
'LearningRate': learning_rate,
'MasterParam': master_param,
}
self.attrs = {
'mu': mu,
'lars_coeff': lars_coeff,
'lars_weight_decay': lars_weight_decay,
'multi_precision': True,
'rescale_grad': rescale_grad
}
fp32_grad = grad.astype("float32")
pnorm = np.sqrt(np.square(master_param).sum())
......@@ -172,16 +166,42 @@ class TestLarsMomentumOpWithMP(OpTest):
local_lr = learning_rate * lars_coeff * pnorm / (
gnorm + lars_weight_decay * pnorm)
fp32_grad = fp32_grad * rescale_grad
velocity_out = mu * velocity + local_lr * (fp32_grad + lars_weight_decay
* master_param)
velocity_out = mu * velocity + local_lr * (
fp32_grad + lars_weight_decay * master_param)
p_new = master_param - velocity_out
param_out = p_new.astype("float16")
master_param_out = p_new
params.append(("SubParam_" + str(i), param))
grads.append(("SubGrad_" + str(i), grad))
velocitys.append(("SubVelocity_" + str(i), velocity))
learning_rates.append(("SubLearning_rate_" + str(i), learning_rate))
velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out))
param_outs.append(("SubParam_out_" + str(i), param_out))
master_params.append(("SubMasterParam_" + str(i), master_param))
master_param_outs.append(
("SubMasterParamOut_" + str(i), master_param_out))
self.inputs = {
'Param': params,
'Grad': grads,
'Velocity': velocitys,
'LearningRate': learning_rates,
'MasterParam': master_params,
}
self.attrs = {
'mu': mu,
'lars_coeff': lars_coeff,
'lars_weight_decay': [lars_weight_decay],
'multi_precision': True,
'rescale_grad': rescale_grad
}
self.outputs = {
'ParamOut': param_out,
'VelocityOut': velocity_out,
'MasterParamOut': master_param_out
'ParamOut': param_outs,
'VelocityOut': velocity_outs,
'MasterParamOut': master_param_outs
}
def test_check_output(self):
......@@ -191,46 +211,65 @@ class TestLarsMomentumOpWithMP(OpTest):
if core.is_float16_supported(place):
self.check_output_with_place(place)
def config(self):
self.params_num = 1
class TestLarsMomentumOp(OpTest):
def setUp(self):
self.config()
self.op_type = "lars_momentum"
mu = 0.0001
lars_coeff = 0.001
lars_weight_decay = 0.0005
params = []
grads = []
velocitys = []
param_outs = []
velocity_outs = []
learning_rates = []
for i in range(self.params_num):
param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
velocity = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([0.001]).astype("float32")
mu = 0.0001
lars_coeff = 0.001
lars_weight_decay = 0.0005
pnorm = np.sqrt(np.square(param).sum())
gnorm = np.sqrt(np.square(grad).sum())
local_lr = learning_rate * lars_coeff * pnorm / (
gnorm + lars_weight_decay * param)
velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay
* param)
param_out = param - velocity_out
params.append(("SubParam_" + str(i), param))
grads.append(("SubGrad_" + str(i), grad))
velocitys.append(("SubVelocity_" + str(i), velocity))
learning_rates.append(("SubLearning_rate_" + str(i), learning_rate))
velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out))
param_outs.append(("SubParam_out_" + str(i), param_out))
self.inputs = {
'Param': param,
'Grad': grad,
'Velocity': velocity,
'LearningRate': learning_rate
'Param': params,
'Grad': grads,
'Velocity': velocitys,
'LearningRate': learning_rates
}
self.attrs = {
'mu': mu,
'lars_coeff': lars_coeff,
'lars_weight_decay': lars_weight_decay
'lars_weight_decay': [lars_weight_decay]
}
pnorm = np.sqrt(np.square(param).sum())
gnorm = np.sqrt(np.square(grad).sum())
local_lr = learning_rate * lars_coeff * pnorm / (
gnorm + lars_weight_decay * param)
velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay *
param)
param_out = param - velocity_out
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
self.outputs = {'ParamOut': param_outs, 'VelocityOut': velocity_outs}
def test_check_output(self):
paddle.enable_static()
self.check_output()
def config(self):
self.params_num = 1
class TestSparseMomentumOp(unittest.TestCase):
def setUp(self):
......