未验证 提交 e106901e 编写于 作者: z8hanghuan's avatar z8hanghuan 提交者: GitHub

support sparse of adam, *test=kunlun (#38483)

* support sparse of adam, *test=kunlun

* add pre-commit-config.yaml

* support sparse of adam in KL2,*test=kunlun

* support sparse of adam in KL2, *test=kunlun

* modify xpu.cmake, *test=kunlun

* support sparse of adam, rm some wait, *test=kunlun

* support sparse of adam, rm some wait, *test=kunlun

* support sparse of adam, *test=kunlun

* support sparse of adam, *test=kunlun

* support sparse of adam, *test=kunlun

* support sparse of adam, *test=kunlun

* support sparse of adam, *test=kunlun
上级 c3796061
...@@ -36,7 +36,7 @@ ENDIF() ...@@ -36,7 +36,7 @@ ENDIF()
if(NOT DEFINED XPU_BASE_URL) if(NOT DEFINED XPU_BASE_URL)
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220104") SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220116")
else() else()
SET(XPU_BASE_URL "${XPU_BASE_URL}") SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif() endif()
......
...@@ -477,6 +477,155 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -477,6 +477,155 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
} }
}; };
#ifdef PADDLE_WITH_XPU
template <typename T>
struct MergeAdd<platform::XPUDeviceContext, T> {
framework::SelectedRows operator()(const platform::XPUDeviceContext& context,
const framework::SelectedRows& input,
const bool sorted_result = false) {
framework::SelectedRows out;
(*this)(context, input, &out, sorted_result);
return out;
}
void operator()(const platform::XPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* output,
const bool sorted_result = false) {
framework::Vector<int64_t> input_rows(input.rows());
if (input_rows.size() == 0) {
return;
}
framework::SelectedRows& out = *output;
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto input_width = input.value().dims()[1];
out.set_rows(merge_rows);
out.set_height(input.height());
out.mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
int r =
xpu::constant<T>(context.x_context(), out.mutable_value()->data<T>(),
merge_rows.size() * input_width, static_cast<T>(0.f));
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU constant op return"
" wrong value[%d %s].",
r, XPUAPIErrorMsg[r]));
std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) {
rows_to_id[merge_rows[i]] = i;
}
auto* out_data = out.mutable_value()->data<T>();
auto* input_data = input.value().data<T>();
int n = input_width;
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]];
auto r = xpu::add(context.x_context(), &input_data[i * input_width],
&out_data[out_i * input_width],
&out_data[out_i * input_width], n);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API return wrong value[%d %s], ", r,
XPUAPIErrorMsg[r]));
}
}
void operator()(const platform::XPUDeviceContext& context,
const std::vector<const framework::SelectedRows*>& inputs,
framework::SelectedRows* output,
const bool sorted_result = false) {
if (inputs.size() == 0) {
VLOG(3) << "no input! return";
return;
}
const framework::SelectedRows* has_value_input = nullptr;
for (auto* in : inputs) {
if (in->rows().size() > 0) {
has_value_input = in;
break;
}
}
if (has_value_input == nullptr) {
VLOG(3) << "no input has value! just return" << std::endl;
return;
}
auto input_width = has_value_input->value().dims()[1];
auto input_height = has_value_input->height();
framework::SelectedRows& out = *output;
std::set<int64_t> merged_row_set;
size_t row_num = 0;
for (auto* input : inputs) {
if (input->rows().size() == 0) {
continue;
}
PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1],
platform::errors::InvalidArgument(
"All inputs should have same "
"dimension except for the first one."));
PADDLE_ENFORCE_EQ(input_height, input->height(),
platform::errors::InvalidArgument(
"All inputs should have same height."));
row_num += input->rows().size();
merged_row_set.insert(input->rows().begin(), input->rows().end());
}
std::vector<int64_t> merge_rows(merged_row_set.begin(),
merged_row_set.end());
if (sorted_result) {
std::sort(merge_rows.begin(), merge_rows.end());
}
out.set_rows(merge_rows);
out.set_height(input_height);
out.mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merged_row_set.size()), input_width}),
context.GetPlace());
int r =
xpu::constant<T>(context.x_context(), out.mutable_value()->data<T>(),
merge_rows.size() * input_width, static_cast<T>(0.f));
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU constant op return"
" wrong value[%d %s].",
r, XPUAPIErrorMsg[r]));
float* out_data = reinterpret_cast<float*>(out.mutable_value()->data<T>());
std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) {
rows_to_id[merge_rows[i]] = i;
}
for (auto* input : inputs) {
if (input->rows().size() == 0) {
continue;
}
auto& input_rows = input->rows();
int n = input_width;
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]];
auto r = xpu::add(
context.x_context(), input->value().data<T>() + i * input_width,
&out_data[out_i * input_width], &out_data[out_i * input_width], n);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API return wrong value[%d %s], ", r,
XPUAPIErrorMsg[r]));
}
}
}
};
#endif
template <typename T> template <typename T>
struct MergeAverage<platform::CPUDeviceContext, T> { struct MergeAverage<platform::CPUDeviceContext, T> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context, framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
...@@ -589,6 +738,10 @@ template struct MergeAdd<platform::CPUDeviceContext, ...@@ -589,6 +738,10 @@ template struct MergeAdd<platform::CPUDeviceContext,
template struct MergeAdd<platform::CPUDeviceContext, template struct MergeAdd<platform::CPUDeviceContext,
paddle::platform::bfloat16>; paddle::platform::bfloat16>;
#ifdef PADDLE_WITH_XPU
template struct MergeAdd<platform::XPUDeviceContext, float>;
#endif
template struct MergeAverage<platform::CPUDeviceContext, int>; template struct MergeAverage<platform::CPUDeviceContext, int>;
template struct MergeAverage<platform::CPUDeviceContext, int64_t>; template struct MergeAverage<platform::CPUDeviceContext, int64_t>;
template struct MergeAverage<platform::CPUDeviceContext, float>; template struct MergeAverage<platform::CPUDeviceContext, float>;
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/optimizers/adam_op.h" #include "paddle/fluid/operators/optimizers/adam_op.h"
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -155,6 +156,11 @@ class AdamOpXPUKernel : public framework::OpKernel<T> { ...@@ -155,6 +156,11 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
mom2_out.template mutable_data<float>(ctx.GetPlace()), mom2_out.template mutable_data<float>(ctx.GetPlace()),
param_out.template mutable_data<float>(ctx.GetPlace()), param_out.template mutable_data<float>(ctx.GetPlace()),
beta1, beta2, epsilon, param.numel()); beta1, beta2, epsilon, param.numel());
xpu_wait(dev_ctx.x_context()->xpu_stream);
PADDLE_ENFORCE_EQ(
r == xpu::Error_t::SUCCESS, true,
platform::errors::External("XPU API return wrong value[%d],", r));
if (!use_global_beta_pow) { if (!use_global_beta_pow) {
// update in cpu and then copy to xpu // update in cpu and then copy to xpu
if (beta1_pow.place() == platform::CPUPlace() && if (beta1_pow.place() == platform::CPUPlace() &&
...@@ -165,7 +171,6 @@ class AdamOpXPUKernel : public framework::OpKernel<T> { ...@@ -165,7 +171,6 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
const float* beta2_pow_p = beta2_pow.template data<float>(); const float* beta2_pow_p = beta2_pow.template data<float>();
beta2_pow_out->mutable_data<float>(platform::CPUPlace())[0] = beta2_pow_out->mutable_data<float>(platform::CPUPlace())[0] =
beta2 * beta2_pow_p[0]; beta2 * beta2_pow_p[0];
xpu_wait(dev_ctx.x_context()->xpu_stream);
} else { } else {
float* beta1_pow_out_p = float* beta1_pow_out_p =
beta1_pow_out->mutable_data<float>(ctx.GetPlace()); beta1_pow_out->mutable_data<float>(ctx.GetPlace());
...@@ -177,23 +182,129 @@ class AdamOpXPUKernel : public framework::OpKernel<T> { ...@@ -177,23 +182,129 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
r, xpu::SUCCESS, r, xpu::SUCCESS,
platform::errors::External( platform::errors::External(
"XPU kernel scale occur error in adamw error code ", r, "XPU kernel scale occur error in adam error code ", r,
XPUAPIErrorMsg[r])); XPUAPIErrorMsg[r]));
r = xpu::scale(dev_ctx.x_context(), beta2_pow_ptr, beta2_pow_out_p, r = xpu::scale(dev_ctx.x_context(), beta2_pow_ptr, beta2_pow_out_p,
beta2_pow.numel(), false, beta2, 0.0f); beta2_pow.numel(), false, beta2, 0.0f);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
r, xpu::SUCCESS, r, xpu::SUCCESS,
platform::errors::External( platform::errors::External(
"XPU kernel scale occur error in adamw error code ", r, "XPU kernel scale occur error in adam error code ", r,
XPUAPIErrorMsg[r])); XPUAPIErrorMsg[r]));
xpu_wait(dev_ctx.x_context()->xpu_stream);
}
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto* grad = ctx.Input<framework::SelectedRows>("Grad");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
if (grad->rows().size() == 0) {
VLOG(3) << "grad row size is 0!!";
return;
}
std::vector<int64_t> cpu_rows(grad->rows().begin(), grad->rows().end());
bool is_strict_sorted = true;
for (size_t i = 1; i < cpu_rows.size(); ++i) {
if (cpu_rows[i - 1] >= cpu_rows[i]) {
is_strict_sorted = false;
break;
} }
}
framework::SelectedRows tmp_grad_merge;
const framework::SelectedRows* grad_merge_ptr;
if (is_strict_sorted) {
grad_merge_ptr = grad;
} else {
scatter::MergeAdd<platform::XPUDeviceContext, T> merge_func;
merge_func(ctx.template device_context<platform::XPUDeviceContext>(),
*grad, &tmp_grad_merge, true);
xpu_wait(dev_ctx.x_context()->xpu_stream);
grad_merge_ptr = &tmp_grad_merge;
}
const T* beta1_pow_ptr = beta1_pow.template data<T>();
const T* beta2_pow_ptr = beta2_pow.template data<T>();
Tensor xpu_beta1_pow;
Tensor xpu_beta2_pow;
if (beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
paddle::framework::TensorCopy(beta1_pow, ctx.GetPlace(), dev_ctx,
&xpu_beta1_pow);
paddle::framework::TensorCopy(beta2_pow, ctx.GetPlace(), dev_ctx,
&xpu_beta2_pow);
dev_ctx.Wait();
beta1_pow_ptr = xpu_beta1_pow.template data<T>();
beta2_pow_ptr = xpu_beta2_pow.template data<T>();
}
auto& grad_merge = *grad_merge_ptr;
auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>();
int row_count = grad_merge.rows().size();
std::vector<int> rows(row_count);
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* xpu_rows = RAII_GUARD.alloc_l3_or_gm<int>(row_count);
std::vector<int64_t> merge_rows(grad_merge.rows().begin(),
grad_merge.rows().end());
for (size_t i = 0; i < grad_merge.rows().size(); ++i) {
rows[i] = static_cast<int>(merge_rows[i]);
}
xpu_wait(dev_ctx.x_context()->xpu_stream);
memory::Copy(ctx.GetPlace(), xpu_rows, platform::CPUPlace(), rows.data(),
row_count * sizeof(int));
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
auto ori_rows = param.numel() / row_numel;
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, int lazy_mode = static_cast<int>(ctx.Attr<bool>("lazy_mode"));
platform::errors::External( int r = xpu::sparse_adam(
"XPU API return wrong value[%d], please check " dev_ctx.x_context(), grad_data, mom1.template data<T>(),
"where Baidu Kunlun Card is properly installed.", mom2.template data<T>(), param.template data<T>(), beta1_pow_ptr,
r)); beta2_pow_ptr, lr.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
param_out.template mutable_data<T>(ctx.GetPlace()), beta1, beta2,
epsilon, ori_rows, xpu_rows, row_numel, grad_merge.rows().size(),
lazy_mode);
PADDLE_ENFORCE_EQ(
r == xpu::Error_t::SUCCESS, true,
platform::errors::External("XPU API return wrong value[%d],", r));
if (!use_global_beta_pow) {
// update in cpu and then copy to xpu
if (beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
const float* beta1_pow_p = beta1_pow.template data<float>();
beta1_pow_out->mutable_data<float>(platform::CPUPlace())[0] =
beta1 * beta1_pow_p[0];
const float* beta2_pow_p = beta2_pow.template data<float>();
beta2_pow_out->mutable_data<float>(platform::CPUPlace())[0] =
beta2 * beta2_pow_p[0];
} else {
float* beta1_pow_out_p =
beta1_pow_out->mutable_data<float>(ctx.GetPlace());
float* beta2_pow_out_p =
beta2_pow_out->mutable_data<float>(ctx.GetPlace());
int r =
xpu::scale(dev_ctx.x_context(), beta1_pow_ptr, beta1_pow_out_p,
beta1_pow.numel(), false, beta1, 0.0f);
PADDLE_ENFORCE_EQ(
r, xpu::SUCCESS,
platform::errors::External(
"XPU kernel scale occur error in adam error code ", r,
XPUAPIErrorMsg[r]));
r = xpu::scale(dev_ctx.x_context(), beta2_pow_ptr, beta2_pow_out_p,
beta2_pow.numel(), false, beta2, 0.0f);
PADDLE_ENFORCE_EQ(
r, xpu::SUCCESS,
platform::errors::External(
"XPU kernel scale occur error in adam error code ", r,
XPUAPIErrorMsg[r]));
}
} }
xpu_wait(dev_ctx.x_context()->xpu_stream);
} else { } else {
PADDLE_ENFORCE_EQ(1, 2, platform::errors::InvalidArgument( PADDLE_ENFORCE_EQ(1, 2, platform::errors::InvalidArgument(
"Variable type not supported by adam_op")); "Variable type not supported by adam_op"));
......
...@@ -216,6 +216,144 @@ def adam_step(inputs, attributes): ...@@ -216,6 +216,144 @@ def adam_step(inputs, attributes):
return param_out, moment1_out, moment2_out return param_out, moment1_out, moment2_out
def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad,
lazy_mode):
'''
Simulate one step of the adam optimizer
:param inputs: dict of inputs
:param attributes: dict of attributes
:return tuple: tuple of output param, moment1, moment2,
beta1 power accumulator and beta2 power accumulator
'''
param = inputs['Param']
# grad = inputs['Grad']
moment1 = inputs['Moment1']
moment2 = inputs['Moment2']
lr = inputs['LearningRate']
beta1_pow = inputs['Beta1Pow']
beta2_pow = inputs['Beta2Pow']
beta1 = attributes['beta1']
beta2 = attributes['beta2']
epsilon = attributes['epsilon']
moment1_out = np.zeros(shape=[height, row_numel])
moment2_out = np.zeros(shape=[height, row_numel])
param_out = np.zeros(shape=[height, row_numel])
def update_row(row_id, update_value):
moment1_out[row_id] = beta1 * moment1[row_id] + (1 - beta1
) * update_value
moment2_out[row_id] = beta2 * moment2[row_id] + (
1 - beta2) * np.square(update_value)
lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
param_out[row_id] = param[row_id] - lr_t * (moment1_out[row_id] / (
np.sqrt(moment2_out[row_id]) + epsilon))
if lazy_mode:
for idx, row_id in enumerate(rows):
update_row(row_id, np_grad[idx])
else:
for row_id in range(param_out.shape[0]):
update_value = np.zeros(np_grad[0].shape).astype("float32")
if row_id in rows:
update_value = np_grad[rows.index(row_id)]
update_row(row_id, update_value)
return param_out, moment1_out, moment2_out
class TestSparseAdamOp(unittest.TestCase):
def setup(self, scope, place, lazy_mode):
beta1 = 0.78
beta2 = 0.836
epsilon = 1e-4
beta1_pow = np.array([beta1**10]).astype("float32")
beta2_pow = np.array([beta2**10]).astype("float32")
height = 10
rows = [0, 4, 7]
self.rows = rows
row_numel = 12
self.row_numel = row_numel
self.dense_inputs = {
"Param": np.full((height, row_numel), 5.0).astype("float32"),
"Moment1": np.full((height, row_numel), 5.0).astype("float32"),
"Moment2": np.full((height, row_numel), 5.0).astype("float32"),
'Beta1Pow': beta1_pow,
'Beta2Pow': beta2_pow,
"LearningRate": np.full((1), 2.0).astype("float32")
}
self.init_output = np.full((height, row_numel), 0.0).astype("float32")
self.attrs = {
'epsilon': epsilon,
'beta1': beta1,
'beta2': beta2,
'min_row_size_to_use_multithread': 2
}
grad_selected_rows = scope.var('Grad').get_selected_rows()
grad_selected_rows.set_height(height)
grad_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), row_numel)).astype("float32")
np_array[0, 0] = 2.0
np_array[2, 8] = 4.0
grad_tensor = grad_selected_rows.get_tensor()
grad_tensor.set(np_array, place)
self.sparse_inputs = ["Grad"]
param_out, mom1, mom2 = adam_step_sparse(self.dense_inputs, self.attrs,
height, rows, row_numel,
np_array, lazy_mode)
self.outputs = {
"ParamOut": param_out,
"Moment1Out": mom1,
"Moment2Out": mom2,
'Beta1PowOut': beta1_pow * beta1,
'Beta2PowOut': beta2_pow * beta2
}
def check_with_place(self, place, lazy_mode):
scope = core.Scope()
self.setup(scope, place, lazy_mode)
op_args = dict()
op_args['lazy_mode'] = lazy_mode
for key, np_array in self.dense_inputs.items():
var = scope.var(key).get_tensor()
var.set(np_array, place)
op_args[key] = key
for s in self.sparse_inputs:
op_args[s] = s
for s in self.outputs:
var = scope.var(s).get_tensor()
var.set(self.init_output, place)
op_args[s] = s
for k in self.attrs:
op_args[k] = self.attrs[k]
# create and run adam operator
adam_op = Operator("adam", **op_args)
adam_op.run(scope, place)
for key, np_array in self.outputs.items():
out_var = scope.var(key).get_tensor()
actual = np.array(out_var)
actual = actual.reshape([actual.size])
np_array = np_array.reshape([np_array.size])
for i in range(np_array.size):
self.assertLess((actual[i] - np_array[i]), 0.00001)
def test_sparse_adam(self):
xpu_version = core.get_xpu_device_version(0)
version_str = "xpu2" if xpu_version == core.XPUVersion.XPU2 else "xpu1"
if "xpu2" == version_str:
self.check_with_place(paddle.XPUPlace(0), False)
class TestAdamOpBetaVariable(OpTest): class TestAdamOpBetaVariable(OpTest):
def setUp(self): def setUp(self):
'''Test Adam Op with beta as Variable '''Test Adam Op with beta as Variable
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册