Unverified commit 0c31579c, authored by limingshu, committed by GitHub

Merge lars op (#35476)

* A first attempt at cudaLaunchCooperativeKernel

* fix bugs

* Totally replace the lars cuda kernel

* Fix bugs

* a test for lars merge

* Adding lars_op_momentum infer_shape

* Fix codes

* use avg_numel instead of max_numel to acquire grid num

* modify unittest files about lars op

* Finally converge when merged-lars works

* fix ctest files

* add merged_operation kernel when cuda version is older than 11

* Fix code style

* fix ctest failure

* fix error

* fix all ctest errors and change lars compute code of cpu

* fix bugs on v100.

* revert python modification about lars

* revert python modification codes
Parent 24418479
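Editor's note: the commit's core trick is NVIDIA's cooperative-launch API, which lets a single kernel do a grid-wide reduction and then the parameter update without a second launch. Below is a minimal, self-contained sketch of that launch pattern; the toy kernel and all names here are illustrative, not part of the commit:

```cuda
#include <algorithm>
#include <cooperative_groups.h>
#include <cuda_runtime.h>

namespace cg = cooperative_groups;

// Toy kernel: phase 1 writes, the whole grid synchronizes, phase 2 reads.
__global__ void TwoPhaseKernel(float* data, int n) {
  cg::grid_group grid = cg::this_grid();
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n) data[tid] *= 2.0f;  // phase 1
  grid.sync();                     // grid-wide barrier; needs cooperative launch
  if (tid < n) data[tid] += 1.0f;  // phase 2 safely sees all phase-1 writes
}

void LaunchTwoPhase(float* data, int n, cudaStream_t stream) {
  int block = 512, num_blocks_per_sm = 0, dev = 0, sm_count = 0;
  cudaGetDevice(&dev);
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
  // grid.sync() is only legal if every block is resident, so cap the grid
  // at what the device can keep active at once.
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm,
                                                TwoPhaseKernel, block, 0);
  int grid = std::min(sm_count * num_blocks_per_sm, (n + block - 1) / block);
  void* args[] = {&data, &n};
  cudaLaunchCooperativeKernel(reinterpret_cast<void*>(TwoPhaseKernel),
                              dim3(grid), dim3(block), args, 0, stream);
}
```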
@@ -13,46 +13,158 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
namespace paddle {
namespace operators {
class LarsMomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("Param"), "Input", "Param", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("Grad"), "Input", "Grad", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("Velocity"), "Input", "Velocity",
"LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("LearningRate"), "Input", "LearningRate",
"LarsMomentum");
OP_INOUT_CHECK(ctx->HasOutputs("ParamOut"), "Output", "ParamOut",
"LarsMomentum");
OP_INOUT_CHECK(ctx->HasOutputs("VelocityOut"), "Output", "VelocityOut",
"LarsMomentum");
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be LoDTensor, but the received is %s",
ctx->GetInputsVarType("Param").front()));
auto lr_dims = ctx->GetInputsDim("LearningRate");
auto grad_dim = ctx->GetInputsDim("Grad");
auto param_dim = ctx->GetInputsDim("Param");
auto velocity_dim = ctx->GetInputsDim("Velocity");
auto lars_weight_decays =
ctx->Attrs().Get<std::vector<float>>("lars_weight_decay");
auto multi_precision = ctx->Attrs().Get<bool>("multi_precision");
PADDLE_ENFORCE_EQ(
param_dim.size(), grad_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) of LarsMomentumOp should have "
"same quantity. But number of Param is [%d] and Grad is [%d].",
param_dim.size(), grad_dim.size()));
PADDLE_ENFORCE_EQ(
param_dim.size(), velocity_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp should "
"have same quantity. But number of Param is [%d] and Velocity "
"is [%d].",
param_dim.size(), velocity_dim.size()));
PADDLE_ENFORCE_EQ(
lars_weight_decays.size(), grad_dim.size(),
platform::errors::InvalidArgument(
"Attr(Lars_weight_decay) and "
"Input(Grad) of LarsMomentumOp should have same quantity. "
"But number of Lars_weight_decay is [%d] and Grad is [%d].",
lars_weight_decays.size(), grad_dim.size()));
if (multi_precision) {
OP_INOUT_CHECK(ctx->HasInputs("MasterParam"), "Input", "MasterParam",
"LarsMomentumMultiPrecision");
OP_INOUT_CHECK(ctx->HasOutputs("MasterParamOut"), "Output",
"MasterParamOut", "LarsMomentumMultiPrecision");
}
for (size_t i = 0; i < lr_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(framework::product(lr_dims[i]), 1,
platform::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
framework::product(lr_dims[i])));
}
for (size_t i = 0; i < param_dim.size(); ++i) {
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Grad")[i],
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx->Inputs("Grad")[i].front(),
ctx->GetInputsVarType("Grad")[i]));
PADDLE_ENFORCE_EQ(
param_dim[i], grad_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) input of LarsMomentumOp shall "
"have same dimension. But Param`s dim is [%s] and Grad's dim "
"is [%s].",
param_dim[i], grad_dim[i]));
PADDLE_ENFORCE_EQ(
param_dim[i], velocity_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp shall have "
"same dimension. But Param dim [%s] differs with Velocity dim "
"[%s].",
param_dim[i], velocity_dim[i]));
}
ctx->SetOutputsDim("ParamOut", param_dim);
ctx->SetOutputsDim("VelocityOut", param_dim);
if (ctx->HasOutputs("MasterParamOut")) {
ctx->SetOutputsDim("MasterParamOut", param_dim);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Param",
             "(LoDTensor, default LoDTensor<float>) "
             "Input parameter that has to be updated")
        .AsDuplicable();
    AddInput("Grad",
             "(LoDTensor, default LoDTensor<float>) "
             "Input gradient of the parameter")
        .AsDuplicable();
    AddInput("Velocity",
             "(LoDTensor, default LoDTensor<float>) "
             "Input velocity (corresponding to the parameter) "
             "that has to be updated")
        .AsDuplicable();
    AddInput("LearningRate",
             "(LoDTensor, default LoDTensor<float>) "
             "Input learning rate")
        .AsDuplicable();
    AddInput("MasterParam", "FP32 master weight for AMP.")
        .AsDuplicable()
        .AsDispensable();
    AddOutput("ParamOut",
              "(LoDTensor) This output is updated parameter. "
              "It shares memory with Input(Param).")
        .AsDuplicable();
    AddOutput("VelocityOut",
              "(LoDTensor) This output is updated velocity. "
              "It shares memory with Input(Velocity).")
        .AsDuplicable();
    AddOutput("MasterParamOut",
              "The updated FP32 master weight for AMP. "
              "It shares memory with Input(MasterParam).")
        .AsDuplicable()
        .AsDispensable();
    AddAttr<float>("mu", "(float) Momentum coefficient");
    AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
        .SetDefault(0.001);
    AddAttr<std::vector<float>>(
        "lars_weight_decay",
        "(std::vector<float>, default 0.0005) LARS weight decay params")
        .SetDefault({0.0005});
    AddAttr<float>("epsilon",
                   "(float, default 0.0) epsilon to avoid Division by Zero.")
        .SetDefault(0.0);
@@ -96,7 +208,7 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
namespace ops = paddle::operators;
REGISTER_OPERATOR(
    lars_momentum, ops::LarsMomentumOp, ops::LarsMomentumOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    ops::LarsMomentumOpVarTypeInference);
......
@@ -18,18 +18,8 @@ limitations under the License. */
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
#include "paddle/fluid/platform/fast_divmod.h"
#if CUDA_VERSION >= 11000
#include <cooperative_groups.h>
#endif
#ifdef __HIPCC__
@@ -38,6 +28,8 @@ limitations under the License. */
#define LARS_BLOCK_SIZE 512
#endif

#define LARS_MAX_MERGED_OPS 150

namespace paddle {
namespace operators {
@@ -53,6 +45,43 @@ __device__ __forceinline__ double Fma(double x, double y, double z) {
  return fma(x, y, z);
}
template <typename T>
class LarsThreadConfig {
 public:
  int grid_for_norm;
  int grid_for_lars;
#if CUDA_VERSION >= 11000
 private:
  int grid_stride;

 public:
  explicit LarsThreadConfig(int64_t numel, int sm_num, int num_blocks_per_sm) {
    int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
    grid_for_lars =
        std::min(std::min(sm_num * num_blocks_per_sm, grid), LARS_BLOCK_SIZE);
    grid_stride = LARS_BLOCK_SIZE * grid_for_lars;
  }

  int GetRepeatTimes(int64_t numel) {
    return (numel + grid_stride - 1) / grid_stride - 1;
  }
#else
  int repeat_times;
  explicit LarsThreadConfig(const int64_t numel) {
    int grid = (numel + LARS_BLOCK_SIZE - 1) / LARS_BLOCK_SIZE;
    grid_for_norm = std::min(grid, LARS_BLOCK_SIZE);
    const int grid_stride = grid_for_norm * LARS_BLOCK_SIZE;
    repeat_times = (numel + grid_stride - 1) / grid_stride - 1;
    // Read 4 fp16 or fp32 values at a time, but only 2 double values at a time.
    grid_for_lars =
        std::is_same<double, T>::value
            ? (numel + (LARS_BLOCK_SIZE << 1) - 1) / (LARS_BLOCK_SIZE << 1)
            : (numel + (LARS_BLOCK_SIZE << 2) - 1) / (LARS_BLOCK_SIZE << 2);
  }
#endif
};
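Editor's note: a worked example of how the CUDA 11 branch of LarsThreadConfig sizes the grid (the concrete numbers are illustrative): with numel = 2^20, LARS_BLOCK_SIZE = 512, sm_num = 80, and num_blocks_per_sm = 4,

```latex
\begin{aligned}
\text{grid} &= \lceil 2^{20} / 512 \rceil = 2048,\\
\text{grid\_for\_lars} &= \min(\min(80 \times 4,\ 2048),\ 512) = 320,\\
\text{grid\_stride} &= 512 \times 320 = 163840,\\
\text{GetRepeatTimes}(2^{20}) &= \lceil 1048576 / 163840 \rceil - 1 = 7 - 1 = 6,
\end{aligned}
```

i.e. each thread makes 6 full grid-stride passes plus one final bounds-checked pass over the tail.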
template <typename T, typename MT, int VecSize, bool IsAmp = false>
__device__ inline void VectorizeLarsUpdate(
    const T* __restrict__ grad, const MT* __restrict__ param,
@@ -85,7 +114,6 @@ __device__ inline void VectorizeLarsUpdate(
      VecType grad_data = grad_vec[i];
      VecMType param_data = param_vec[i];
      VecMType velocity_data = velocity_vec[i];
#pragma unroll
      for (int j = 0; j < VecSize; ++j) {
        MT grad_val = static_cast<MT>(grad_data[j]) * rescale_grad;
@@ -116,41 +144,49 @@ __device__ inline void VectorizeLarsUpdate(
  }
}
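Editor's note: VectorizeLarsUpdate relies on vector-width loads (most of its body is elided in the diff above). As a rough standalone illustration of the underlying idea, not the Paddle code itself, a grid-stride loop over float4 moves 16 bytes per memory transaction, with a scalar loop for the tail:

```cuda
// Illustrative only: scale an array using float4 loads plus a scalar tail.
__global__ void ScaleVectorized(const float* __restrict__ in,
                                float* __restrict__ out, int64_t numel,
                                float alpha) {
  int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t n_vec = numel >> 2;  // number of whole float4 chunks
  const float4* in4 = reinterpret_cast<const float4*>(in);
  float4* out4 = reinterpret_cast<float4*>(out);
  for (int64_t i = tid; i < n_vec; i += stride) {
    float4 v = in4[i];
    v.x *= alpha; v.y *= alpha; v.z *= alpha; v.w *= alpha;
    out4[i] = v;
  }
  // Scalar tail for the last numel % 4 elements.
  for (int64_t i = (n_vec << 2) + tid; i < numel; i += stride) {
    out[i] = in[i] * alpha;
  }
}
```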
#if CUDA_VERSION >= 11000
/* With CUDA 11 and later, cooperative_groups can be used without the
   --rdc=true compile flag, so the L2-norm kernel can be declared __device__
   and cooperative_groups::grid_group can be used. On older versions, adding
   that flag may hurt kernel performance across paddle, so the L2-norm kernel
   must stay a __global__ kernel. */
// TODO(limingshu): declaration of cooperative_groups wrapper is invalid in host.
template <typename T, typename MT>
__forceinline__ __device__ void L2NormKernel(
    const cooperative_groups::grid_group* cg,
#else
template <typename T, typename MT>
__global__ void L2NormKernel(
#endif
    const T* __restrict__ p_data, const T* __restrict__ g_data,
    MT* __restrict__ p_buffer, MT* __restrict__ g_buffer, const int64_t numel,
    const int repeat_times, const MT rescale_grad, const int thresh = 0,
    MT* __restrict__ p_n = nullptr, MT* __restrict__ g_n = nullptr) {
  int tid = threadIdx.x + blockDim.x * blockIdx.x;
  int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
  const MT rescale_pow = rescale_grad * rescale_grad;
  __shared__ MT s_buffer[2];
  s_buffer[0] = static_cast<MT>(0);
  s_buffer[1] = static_cast<MT>(0);
  MT p_tmp = static_cast<MT>(0);
  MT g_tmp = static_cast<MT>(0);
  if (repeat_times == 0) {
    if (tid < numel) {
      p_tmp = static_cast<MT>(p_data[tid]);
      g_tmp = static_cast<MT>(g_data[tid]);
    }
    s_buffer[0] += math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
    s_buffer[1] += math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
  } else {
    /* Avoid occupying too much temp buffer. Slice the data into two parts:
       the front part, whose size is an exact multiple of the grid-thread
       count, is handled in the loop below; the remainder is handled in a
       separate step to avoid reading out of bounds. */
    for (int i = 0; i < repeat_times; ++i) {
      p_tmp = static_cast<MT>(p_data[tid]);
      g_tmp = static_cast<MT>(g_data[tid]);
      tid += grid_stride;
      s_buffer[0] += math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
      s_buffer[1] += math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
      __syncthreads();
    }
    MT p_val = 0;
@@ -168,69 +204,46 @@ LARS_FUNCTION_FLAG void L2NormKernel(
    p_buffer[blockIdx.x] = s_buffer[0];
    g_buffer[blockIdx.x] = s_buffer[1];
  }
#if CUDA_VERSION >= 11000
  cg->sync();  // Grid sync for writing partial results to global memory.
  MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0;
  MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0;
  *p_n = Sqrt(math::blockReduceSum<MT>(p_part_sum, FINAL_MASK));
  *g_n = Sqrt(rescale_pow * math::blockReduceSum<MT>(g_part_sum, FINAL_MASK));
#endif
}
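Editor's note: math::blockReduceSum is a Paddle-internal helper. For readers unfamiliar with it, a typical warp-shuffle implementation of a block-wide sum looks roughly like this, a sketch of the usual pattern rather than Paddle's actual source:

```cuda
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffffu

// Warp-level sum via shuffle-down; lane 0 ends up holding the warp total.
template <typename T>
__device__ T WarpReduceSum(T val) {
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_down_sync(FULL_MASK, val, offset);
  return val;
}

// Block-level sum: reduce within each warp, stage the per-warp totals in
// shared memory, then reduce those totals inside the first warp.
template <typename T>
__device__ T BlockReduceSum(T val) {
  __shared__ T shared[32];
  int lane = threadIdx.x & 31;
  int wid = threadIdx.x >> 5;
  val = WarpReduceSum(val);
  if (lane == 0) shared[wid] = val;
  __syncthreads();
  int num_warps = (blockDim.x + 31) >> 5;
  val = (threadIdx.x < num_warps) ? shared[lane] : static_cast<T>(0);
  if (wid == 0) val = WarpReduceSum(val);
  return val;  // thread 0 holds the block total
}
```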
template <typename T, typename MT>
__forceinline__ __device__ void MomentumUpdate(
    const T* __restrict__ param, const T* __restrict__ grad,
    const MT* __restrict__ velocity, T* param_out, MT* velocity_out,
    const MT* __restrict__ master_param, MT* __restrict__ master_param_out,
    const MT* __restrict__ learning_rate, const MT mu,
    const MT lars_weight_decay, const MT lars_coeff, const MT epsilon,
    const MT rescale_grad, const MT param_norm, const MT grad_norm,
    const int tid, const int grid_stride, const int64_t numel,
    const bool is_amp) {
  const MT lr = learning_rate[0];
  MT local_lr = lr;
  if (lars_weight_decay > static_cast<MT>(0)) {
    local_lr = lr * lars_coeff * param_norm /
               (fma(lars_weight_decay, param_norm, grad_norm) + epsilon);
  }
  if (is_amp) {
    VectorizeLarsUpdate<T, MT, /*VecSize=*/4, /*IsAmp=*/true>(
        grad, master_param, velocity, param_out, velocity_out, mu, local_lr,
        lars_weight_decay, rescale_grad, tid, grid_stride, numel,
        master_param_out);
  } else {
    if (std::is_same<T, float>::value ||
        std::is_same<T, paddle::platform::float16>::value) {
      /* TODO(limingshu): pointer cast may damage memory accessing for fp16 */
      VectorizeLarsUpdate<T, MT, /*VecSize=*/4, /*IsAmp=*/false>(
          grad, reinterpret_cast<const MT*>(param), velocity, param_out,
          velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid,
          grid_stride, numel);
    } else {
      VectorizeLarsUpdate<T, MT, /*VecSize=*/2, /*IsAmp=*/false>(
          grad, reinterpret_cast<const MT*>(param), velocity, param_out,
          velocity_out, mu, local_lr, lars_weight_decay, rescale_grad, tid,
          grid_stride, numel);
@@ -238,94 +251,235 @@ __global__ void MomentumLarsKernel(
  }
}
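Editor's note: for reference, the update that MomentumUpdate implements is the standard LARS rule (You et al., "Large Batch Training of Convolutional Networks"). With parameter p, gradient g, weight decay λ, LARS coefficient η, and momentum μ:

```latex
\begin{aligned}
\text{local\_lr} &= \text{lr} \cdot \frac{\eta\,\lVert p \rVert_2}{\lambda \lVert p \rVert_2 + \lVert g \rVert_2 + \epsilon},\\
v_{t+1} &= \mu\, v_t + \text{local\_lr}\,(g + \lambda p),\\
p_{t+1} &= p_t - v_{t+1}.
\end{aligned}
```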
#if CUDA_VERSION >= 11000
template <typename T, typename MT>
struct LarsParamWarpper {
  int64_t numel_arr[LARS_MAX_MERGED_OPS];
  int repeat_arr[LARS_MAX_MERGED_OPS];
  const T* __restrict__ p_arr[LARS_MAX_MERGED_OPS];
  const T* __restrict__ g_arr[LARS_MAX_MERGED_OPS];
  const MT* __restrict__ v_arr[LARS_MAX_MERGED_OPS];
  const MT* __restrict__ lr_arr[LARS_MAX_MERGED_OPS];
  const MT* __restrict__ master_p_arr[LARS_MAX_MERGED_OPS];
  T* __restrict__ p_out_arr[LARS_MAX_MERGED_OPS];
  MT* __restrict__ v_out_arr[LARS_MAX_MERGED_OPS];
  MT* __restrict__ master_p_out_arr[LARS_MAX_MERGED_OPS];
  MT weight_decay_arr[LARS_MAX_MERGED_OPS];
};

template <typename T, typename MT>
__global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT>* lars_warpper,
                                         MT* __restrict__ p_buffer,
                                         MT* __restrict__ g_buffer,
                                         const int op_num, const MT mu,
                                         const MT lars_coeff, const MT epsilon,
                                         const MT rescale_grad,
                                         const bool is_amp) {
  int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
  for (int i = 0; i < op_num; ++i) {
    int numel = lars_warpper->numel_arr[i];
    MT param_norm = static_cast<MT>(0);
    MT grad_norm = static_cast<MT>(0);
    L2NormKernel<T, MT>(&cg, lars_warpper->p_arr[i], lars_warpper->g_arr[i],
                        p_buffer, g_buffer, numel, lars_warpper->repeat_arr[i],
                        rescale_grad, 0, &param_norm, &grad_norm);
    MomentumUpdate<T, MT>(
        lars_warpper->p_arr[i], lars_warpper->g_arr[i],
        lars_warpper->v_out_arr[i], lars_warpper->p_out_arr[i],
        lars_warpper->v_out_arr[i], lars_warpper->master_p_arr[i],
        lars_warpper->master_p_out_arr[i], lars_warpper->lr_arr[i], mu,
        lars_warpper->weight_decay_arr[i], lars_coeff, epsilon, rescale_grad,
        param_norm, grad_norm, tid, grid_stride, numel, is_amp);
  }
}
#endif
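Editor's note: the LarsParamWarpper pattern above — a fixed-capacity struct of device pointers, built on the host and shipped to the GPU in a single copy so one kernel can iterate over many tensors — can be illustrated in isolation. This sketch uses hypothetical names (MergedDesc, MergedUpdate) and a toy SGD step; it is not the Paddle implementation:

```cuda
#include <cuda_runtime.h>

constexpr int kMaxOps = 4;
struct MergedDesc {
  int64_t numel[kMaxOps];
  const float* grad[kMaxOps];
  float* param[kMaxOps];
};

// One kernel walks every merged op; all blocks cooperate on each tensor.
__global__ void MergedUpdate(const MergedDesc* desc, int op_num, float lr) {
  int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  for (int i = 0; i < op_num; ++i) {
    for (int64_t j = tid; j < desc->numel[i]; j += stride) {
      desc->param[i][j] -= lr * desc->grad[i][j];  // toy SGD step
    }
  }
}

void LaunchMerged(const MergedDesc& host_desc, int op_num, float lr,
                  cudaStream_t stream) {
  MergedDesc* dev_desc = nullptr;
  cudaMalloc(&dev_desc, sizeof(MergedDesc));
  // One host-to-device copy carries every tensor's pointer and size.
  cudaMemcpyAsync(dev_desc, &host_desc, sizeof(MergedDesc),
                  cudaMemcpyHostToDevice, stream);
  MergedUpdate<<<64, 256, 0, stream>>>(dev_desc, op_num, lr);
  cudaStreamSynchronize(stream);
  cudaFree(dev_desc);
}
```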
template <typename T, typename MT>
__global__ void MomentumLarsKernel(
    const T* __restrict__ param, const T* __restrict__ grad,
    const MT* __restrict__ velocity, T* param_out, MT* velocity_out,
    const MT* __restrict__ master_param, MT* __restrict__ master_param_out,
    const MT* __restrict__ learning_rate, MT* __restrict__ p_buffer,
    MT* __restrict__ g_buffer, const MT mu, const MT lars_coeff,
    const MT lars_weight_decay, const MT epsilon, const MT rescale_grad,
    const int repeat_times, const int thresh, const int64_t numel,
    const bool is_amp) {
  int tid = threadIdx.x + blockIdx.x * blockDim.x;
  int grid_stride = gridDim.x * LARS_BLOCK_SIZE;
#if CUDA_VERSION >= 11000
  const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
  MT param_norm = static_cast<MT>(0);
  MT grad_norm = static_cast<MT>(0);
  L2NormKernel<T, MT>(&cg, param, grad, p_buffer, g_buffer, numel, repeat_times,
                      rescale_grad, gridDim.x, &param_norm, &grad_norm);
#else
  const MT rescale_grad_pow = rescale_grad * rescale_grad;
  MT param_part_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0;
  MT grad_part_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
  __syncthreads();
  MT param_norm = Sqrt(math::blockReduceSum<MT>(param_part_norm, FINAL_MASK));
  MT grad_norm = Sqrt(rescale_grad_pow *
                      math::blockReduceSum<MT>(grad_part_norm, FINAL_MASK));
#endif
  MomentumUpdate<T, MT>(param, grad, velocity, param_out, velocity_out,
                        master_param, master_param_out, learning_rate, mu,
                        lars_weight_decay, lars_coeff, epsilon, rescale_grad,
                        param_norm, grad_norm, tid, grid_stride, numel, is_amp);
}
template <typename T, typename MT>
inline void SeparatedLarsMomentumOpCUDAKernel(
    const platform::CUDADeviceContext& cuda_ctx, const T* param_data,
    T* param_out_data, const MT* velocity_data, MT* velocity_out_data,
    const T* grad_data, const MT* lr, MT* p_buffer, MT* g_buffer, const MT mu,
    const MT lars_coeff, const MT weight_decay, const MT epsilon,
    const MT rescale_grad, const int64_t numel, const MT* master_param_data,
    MT* master_out_data, const bool is_amp) {
  LarsThreadConfig<T> lars_thread_config(numel);
  L2NormKernel<T, MT><<<lars_thread_config.grid_for_norm, LARS_BLOCK_SIZE, 0,
                        cuda_ctx.stream()>>>(
      param_data, grad_data, p_buffer, g_buffer, numel,
      lars_thread_config.repeat_times, rescale_grad);

  MomentumLarsKernel<T, MT><<<lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE,
                              0, cuda_ctx.stream()>>>(
      param_data, grad_data, velocity_data, param_out_data, velocity_out_data,
      master_param_data, master_out_data, lr, p_buffer, g_buffer, mu,
      lars_coeff, weight_decay, epsilon, rescale_grad, 0,
      lars_thread_config.grid_for_norm, numel, is_amp);
}
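Editor's note: in this pre-CUDA-11 fallback, the two launches communicate through p_buffer/g_buffer: each of the B blocks of L2NormKernel writes one partial sum, and a single block of MomentumLarsKernel folds them together. In symbols,

```latex
P_b = \sum_{t \in \text{block } b} p_t^2, \qquad
\lVert p \rVert_2 = \sqrt{\textstyle\sum_{b=0}^{B-1} P_b}, \qquad
\lVert g \rVert_2 = \sqrt{r^2 \textstyle\sum_{b=0}^{B-1} G_b},
```

where r is rescale_grad. This is also why grid_for_norm is capped at LARS_BLOCK_SIZE: one block of LARS_BLOCK_SIZE threads must be able to read every partial sum.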
template <typename DeviceContext, typename T>
class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
  using MT = MultiPrecisionType<T>;

 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    int num_blocks_per_sm = 0;
    bool multi_precision = ctx.Attr<bool>("multi_precision");
    auto& cuda_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    int sm_num = cuda_ctx.GetSMCount();
    framework::Tensor tmp_buffer_t =
        ctx.AllocateTmpTensor<MT, platform::CUDADeviceContext>(
            {LARS_BLOCK_SIZE << 1}, cuda_ctx);
    auto* p_buffer = tmp_buffer_t.mutable_data<MT>(ctx.GetPlace());
    auto* g_buffer = p_buffer + LARS_BLOCK_SIZE;
    MT mu = static_cast<MT>(ctx.Attr<float>("mu"));
    MT lars_coeff = static_cast<MT>(ctx.Attr<float>("lars_coeff"));
    MT epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
    MT rescale_grad = static_cast<MT>(ctx.Attr<float>("rescale_grad"));
    auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
    auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
    auto param = ctx.MultiInput<framework::LoDTensor>("Param");
    auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity");
    auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
    auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
    auto learning_rate = ctx.MultiInput<framework::LoDTensor>("LearningRate");
    auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam");
    auto master_param_out =
        ctx.MultiOutput<framework::LoDTensor>("MasterParamOut");
    int op_num = grad.size();
#if CUDA_VERSION >= 11000
    if (op_num > 1) {
      LarsParamWarpper<T, MT> lars_warpper;
      PADDLE_ENFORCE_LT(
          op_num, LARS_MAX_MERGED_OPS,
          platform::errors::InvalidArgument(
              "The maximum number of merged-ops supported is (%d), but "
              "the number of lars ops required for training this model "
              "is (%d)\n",
              LARS_MAX_MERGED_OPS, op_num));
      /* Implementation of the lars optimizer consists of two steps:
         1. Figure out the L2 norm of the grad data and the param data.
         2. Update param and velocity using the L2 norm results.
         Step 1 and step 2 can be merged with the cudaLaunchCooperativeKernel
         API provided by NVIDIA:
         - The thread count shall be less than the physical SM thread limit.
         - Thread blocks of the launched kernel can cooperate and synchronize
           as they execute. */
      cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &num_blocks_per_sm, MergedMomentumLarsKernel<T, MT>, LARS_BLOCK_SIZE,
          sizeof(MT) << 1);
      size_t total_numel = 0;
      for (int i = 0; i < op_num; ++i) {
        size_t temp_numel = param[i]->numel();
        total_numel += temp_numel;
        lars_warpper.numel_arr[i] = temp_numel;
        lars_warpper.p_arr[i] = param[i]->data<T>();
        lars_warpper.g_arr[i] = grad[i]->data<T>();
        lars_warpper.v_arr[i] = velocity[i]->data<MT>();
        lars_warpper.lr_arr[i] = learning_rate[i]->data<MT>();
        lars_warpper.p_out_arr[i] =
            param_out[i]->mutable_data<T>(ctx.GetPlace());
        lars_warpper.v_out_arr[i] =
            velocity_out[i]->mutable_data<MT>(ctx.GetPlace());
        lars_warpper.weight_decay_arr[i] = static_cast<MT>(weight_decay_arr[i]);
      }
      int64_t avg_numel = total_numel / op_num;
      LarsThreadConfig<float> lars_thread_config(avg_numel, sm_num,
                                                 num_blocks_per_sm);
      for (int i = 0; i < op_num; ++i) {
        lars_warpper.repeat_arr[i] =
            lars_thread_config.GetRepeatTimes(lars_warpper.numel_arr[i]);
      }
      if (multi_precision) {
        for (int i = 0; i < op_num; ++i) {
          lars_warpper.master_p_arr[i] = master_param[i]->data<MT>();
          lars_warpper.master_p_out_arr[i] =
              master_param_out[i]->mutable_data<MT>(ctx.GetPlace());
        }
      }
      auto merged_buf = memory::Alloc(cuda_ctx, sizeof(lars_warpper));
      auto* merged_ptr =
          reinterpret_cast<LarsParamWarpper<T, MT>*>(merged_buf->ptr());
      memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, cuda_ctx.GetPlace()),
                   reinterpret_cast<void*>(merged_ptr), platform::CPUPlace(),
                   reinterpret_cast<void*>(&lars_warpper), sizeof(lars_warpper),
                   cuda_ctx.stream());
      void* cuda_param[] = {reinterpret_cast<void*>(&merged_ptr),
                            reinterpret_cast<void*>(&p_buffer),
                            reinterpret_cast<void*>(&g_buffer),
                            reinterpret_cast<void*>(&op_num),
                            reinterpret_cast<void*>(&mu),
                            reinterpret_cast<void*>(&lars_coeff),
                            reinterpret_cast<void*>(&epsilon),
                            reinterpret_cast<void*>(&rescale_grad),
                            reinterpret_cast<void*>(&multi_precision)};
      // Launch all SM threads; the thread blocks cooperate and synchronize.
      cudaLaunchCooperativeKernel(
          reinterpret_cast<void*>(MergedMomentumLarsKernel<T, MT>),
          lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0,
          cuda_ctx.stream());
    } else {
      auto* param_data = param[0]->data<T>();
      auto* grad_data = grad[0]->data<T>();
      auto* velocity_data = velocity[0]->data<MT>();
      auto* lr = learning_rate[0]->data<MT>();
      auto* param_out_data = param_out[0]->mutable_data<T>(ctx.GetPlace());
      auto* velocity_out_data =
          velocity_out[0]->mutable_data<MT>(ctx.GetPlace());
      const MT* master_param_data =
          multi_precision ? master_param[0]->data<MT>() : nullptr;
      MT* master_param_out_data =
          multi_precision
              ? master_param_out[0]->mutable_data<MT>(ctx.GetPlace())
              : nullptr;
      int64_t numel = param[0]->numel();
      MT lars_weight_decay = weight_decay_arr[0];
      // Figure out how many blocks can be active on each SM.
      cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &num_blocks_per_sm, MomentumLarsKernel<T, MT>, LARS_BLOCK_SIZE,
          sizeof(MT) << 1);
      LarsThreadConfig<float> lars_thread_config(numel, sm_num,
                                                 num_blocks_per_sm);
      int repeat_times = lars_thread_config.GetRepeatTimes(numel);
      int thresh = 0;
      // Uniform kernel parameters for cudaLaunchCooperativeKernel.
      void* cuda_param[] = {
          reinterpret_cast<void*>(&param_data),
          reinterpret_cast<void*>(&grad_data),
@@ -344,38 +498,31 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
          reinterpret_cast<void*>(&rescale_grad),
          reinterpret_cast<void*>(&repeat_times),
          reinterpret_cast<void*>(&thresh),  // Just a placeholder.
          reinterpret_cast<void*>(&numel),
          reinterpret_cast<void*>(&multi_precision)};
      // Launch all SM threads.
      cudaLaunchCooperativeKernel(
          reinterpret_cast<void*>(MomentumLarsKernel<T, MT>),
          lars_thread_config.grid_for_lars, LARS_BLOCK_SIZE, cuda_param, 0,
          cuda_ctx.stream());
    }
#else
    for (int i = 0; i < op_num; ++i) {
      const MT* master_param_data =
          multi_precision ? master_param[i]->data<MT>() : nullptr;
      MT* master_param_out_data =
          multi_precision
              ? master_param_out[i]->mutable_data<MT>(ctx.GetPlace())
              : nullptr;
      SeparatedLarsMomentumOpCUDAKernel<T, MT>(
          cuda_ctx, param[i]->data<T>(),
          param_out[i]->mutable_data<T>(ctx.GetPlace()),
          velocity[i]->data<MT>(),
          velocity_out[i]->mutable_data<MT>(ctx.GetPlace()), grad[i]->data<T>(),
          learning_rate[i]->data<MT>(), p_buffer, g_buffer, mu, lars_coeff,
          weight_decay_arr[i], epsilon, rescale_grad, param[i]->numel(),
          master_param_data, master_param_out_data, multi_precision);
    }
#endif
  }
};
......
@@ -23,36 +23,29 @@ template <typename T>
class LarsMomentumOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
    auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
    auto param = ctx.MultiInput<framework::LoDTensor>("Param");
    auto velocity = ctx.MultiInput<framework::LoDTensor>("Velocity");
    auto learning_rate = ctx.MultiInput<framework::LoDTensor>("LearningRate");
    auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");
    auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
    T mu = static_cast<T>(ctx.Attr<float>("mu"));
    T lars_coeff = ctx.Attr<float>("lars_coeff");
    T epsilon = ctx.Attr<float>("epsilon");

    int op_num = param.size();
    for (int i = 0; i < op_num; ++i) {
      auto* lr = learning_rate[i]->data<T>();
      T lars_weight_decay = weight_decay_arr[i];
      param_out[i]->mutable_data<T>(ctx.GetPlace());
      velocity_out[i]->mutable_data<T>(ctx.GetPlace());

      auto p_out = framework::EigenVector<T>::Flatten(*(param_out[i]));
      auto v_out = framework::EigenVector<T>::Flatten(*(velocity_out[i]));
      auto p = framework::EigenVector<T>::Flatten(*(param[i]));
      auto v = framework::EigenVector<T>::Flatten(*(velocity[i]));
      auto g = framework::EigenVector<T>::Flatten(*(grad[i]));

      framework::Tensor p_norm_t, g_norm_t;
      p_norm_t.Resize({1});
@@ -61,9 +54,9 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {
      g_norm_t.mutable_data<T>(ctx.GetPlace());
      auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
      auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
      ep_norm = p.square().sum().sqrt();
      eg_norm = g.square().sum().sqrt();

      T local_lr = lr[0];
      if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
        local_lr = lr[0] * lars_coeff * ep_norm(0) /
@@ -72,6 +65,7 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {
      v_out = v * mu + local_lr * (g + lars_weight_decay * p);
      p_out = p - v_out;
    }
  }
};
} // namespace operators
......
@@ -2066,7 +2066,7 @@ class LarsMomentumOptimizer(Optimizer):
        attrs = {
            "mu": self._momentum,
            "lars_coeff": self._lars_coeff,
            "lars_weight_decay": [_lars_weight_decay],
            "multi_precision": find_master,
            "rescale_grad": self._rescale_grad
        }
......
@@ -103,7 +103,7 @@ class TestFleetLarsMetaOptimizer(unittest.TestCase):
                'op_role_var')[0] or ".b" in op.attr('op_role_var')[0])
        ]
        for op in ops_without_wd:
            self.assertEqual(op.attr('lars_weight_decay')[0], 0)

    def test_lars_apply_with_amp(self):
        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
......
@@ -138,33 +138,27 @@ class TestMomentumOp2(OpTest):
                 "core is not compiled with CUDA")
class TestLarsMomentumOpWithMP(OpTest):
    def setUp(self):
        self.config()
        self.op_type = "lars_momentum"
        mu = 0.0001
        lars_coeff = 0.001
        lars_weight_decay = 0.0005
        rescale_grad = 1.0
        params = []
        grads = []
        velocitys = []
        learning_rates = []
        master_params = []
        param_outs = []
        velocity_outs = []
        master_param_outs = []
        for i in range(self.params_num):
            master_param = np.random.random((123, 321)).astype("float32")
            param = master_param.astype("float16")
            grad = np.random.random((123, 321)).astype("float16")
            velocity = np.zeros((123, 321)).astype("float32")
            learning_rate = np.array([0.001]).astype("float32")

            fp32_grad = grad.astype("float32")
            pnorm = np.sqrt(np.square(master_param).sum())
@@ -172,16 +166,42 @@ class TestLarsMomentumOpWithMP(OpTest):
            local_lr = learning_rate * lars_coeff * pnorm / (
                gnorm + lars_weight_decay * pnorm)
            fp32_grad = fp32_grad * rescale_grad
            velocity_out = mu * velocity + local_lr * (
                fp32_grad + lars_weight_decay * master_param)
            p_new = master_param - velocity_out
            param_out = p_new.astype("float16")
            master_param_out = p_new

            params.append(("SubParam_" + str(i), param))
            grads.append(("SubGrad_" + str(i), grad))
            velocitys.append(("SubVelocity_" + str(i), velocity))
            learning_rates.append(("SubLearning_rate_" + str(i), learning_rate))
            velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out))
            param_outs.append(("SubParam_out_" + str(i), param_out))
            master_params.append(("SubMasterParam_" + str(i), master_param))
            master_param_outs.append(
                ("SubMasterParamOut_" + str(i), master_param_out))

        self.inputs = {
            'Param': params,
            'Grad': grads,
            'Velocity': velocitys,
            'LearningRate': learning_rates,
            'MasterParam': master_params,
        }
        self.attrs = {
            'mu': mu,
            'lars_coeff': lars_coeff,
            'lars_weight_decay': [lars_weight_decay],
            'multi_precision': True,
            'rescale_grad': rescale_grad
        }
        self.outputs = {
            'ParamOut': param_outs,
            'VelocityOut': velocity_outs,
            'MasterParamOut': master_param_outs
        }

    def test_check_output(self):
@@ -191,46 +211,65 @@ class TestLarsMomentumOpWithMP(OpTest):
        if core.is_float16_supported(place):
            self.check_output_with_place(place)

    def config(self):
        self.params_num = 1


class TestLarsMomentumOp(OpTest):
    def setUp(self):
        self.config()
        self.op_type = "lars_momentum"
        mu = 0.0001
        lars_coeff = 0.001
        lars_weight_decay = 0.0005
        params = []
        grads = []
        velocitys = []
        param_outs = []
        velocity_outs = []
        learning_rates = []
        for i in range(self.params_num):
            param = np.random.random((123, 321)).astype("float32")
            grad = np.random.random((123, 321)).astype("float32")
            velocity = np.zeros((123, 321)).astype("float32")
            learning_rate = np.array([0.001]).astype("float32")

            pnorm = np.sqrt(np.square(param).sum())
            gnorm = np.sqrt(np.square(grad).sum())
            local_lr = learning_rate * lars_coeff * pnorm / (
                gnorm + lars_weight_decay * param)
            velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay
                                                       * param)
            param_out = param - velocity_out

            params.append(("SubParam_" + str(i), param))
            grads.append(("SubGrad_" + str(i), grad))
            velocitys.append(("SubVelocity_" + str(i), velocity))
            learning_rates.append(("SubLearning_rate_" + str(i), learning_rate))
            velocity_outs.append(("SubVelocity_out_" + str(i), velocity_out))
            param_outs.append(("SubParam_out_" + str(i), param_out))

        self.inputs = {
            'Param': params,
            'Grad': grads,
            'Velocity': velocitys,
            'LearningRate': learning_rates
        }
        self.attrs = {
            'mu': mu,
            'lars_coeff': lars_coeff,
            'lars_weight_decay': [lars_weight_decay]
        }
        self.outputs = {'ParamOut': param_outs, 'VelocityOut': velocity_outs}

    def test_check_output(self):
        paddle.enable_static()
        self.check_output()

    def config(self):
        self.params_num = 1


class TestSparseMomentumOp(unittest.TestCase):
    def setUp(self):
......