提交 641b4c0f 编写于 作者: T typhoonzero

wip

上级 74b12288
......@@ -105,48 +105,18 @@ struct SparseAdagradFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param) {
// 1. g_m.rows = set(g.rows)
auto grad_rows = grad.rows();
std::set<int64_t> row_set(grad_rows.begin(), grad_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto grad_width = grad.value().dims()[1];
std::unique_ptr<framework::SelectedRows> grad_merge{
new framework::SelectedRows()};
grad_merge->set_rows(merge_rows);
grad_merge->set_height(grad.height());
grad_merge->mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), grad_width}),
context.GetPlace());
math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
constant_functor(context, grad_merge->mutable_value(), 0.0);
auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
auto* grad_data = grad.value().data<T>();
for (size_t i = 0; i < grad_rows.size(); i++) {
size_t grad_merge_i = FindPos(merge_rows, grad_rows[i]);
for (int64_t j = 0; j < grad_width; j++) {
grad_merge_data[grad_merge_i * grad_width + j] +=
grad_data[i * grad_width + j];
}
}
math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
auto grad_merge = merge_func(context, grad);
auto& merge_rows = grad_merge.rows();
auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
// 2. m += g_m * g_m
std::unique_ptr<framework::SelectedRows> grad_square{
new framework::SelectedRows()};
grad_square->set_rows(grad_merge->rows());
grad_square->set_height(grad_merge->height());
grad_square->mutable_value()->mutable_data<T>(grad_merge->value().dims(),
context.GetPlace());
auto gs =
framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
gs.device(*context.eigen_device()) = gm * gm;
math::scatter::Mul<platform::CPUDeviceContext, T> sqare_func;
auto grad_square = sqare_func(context, grad_merge, grad_merge);
math::SelectedRowsAddToTensor<platform::CPUDeviceContext, T> functor;
functor(context, *grad_square, moment);
functor(context, grad_square, moment);
// 3. update parameter
auto* lr = learning_rate.data<T>();
......
......@@ -78,51 +78,17 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param) {
// 1. g_m.rows = set(g.rows)
auto grad_rows = grad.rows();
std::set<int64_t> row_set(grad_rows.begin(), grad_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto grad_width = grad.value().dims()[1];
std::unique_ptr<framework::SelectedRows> grad_merge{
new framework::SelectedRows()};
grad_merge->set_rows(merge_rows);
grad_merge->set_height(grad.height());
grad_merge->mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), grad_width}),
context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
constant_functor(context, grad_merge->mutable_value(), 0.0);
auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
auto* grad_data = grad.value().data<T>();
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid1(1, grad_rows.size());
MergeGradKernel<
T, 256><<<grid1, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(grad_data, grad.rows().data(),
grad_merge_data, grad_merge->rows().data(),
grad_merge->rows().size(), grad_width);
math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
auto grad_merge = merge_func(context, grad);
auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
auto& merge_rows = grad_merge.rows;
// 2. m += g_m * g_m
std::unique_ptr<framework::SelectedRows> grad_square{
new framework::SelectedRows()};
grad_square->set_rows(grad_merge->rows());
grad_square->set_height(grad_merge->height());
grad_square->mutable_value()->mutable_data<T>(grad_merge->value().dims(),
context.GetPlace());
auto gs =
framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
gs.device(*context.eigen_device()) = gm * gm;
math::scatter::Mul<platform::CPUDeviceContext, T> sqare_func;
auto grad_square = sqare_func(context, grad_merge, grad_merge);
math::SelectedRowsAddToTensor<platform::CUDADeviceContext, T> functor;
functor(context, *grad_square, moment);
functor(context, grad_square, moment);
// 3. update parameter
auto* lr = learning_rate.data<T>();
......
......@@ -16,11 +16,14 @@ limitations under the License. */
#include <math.h> // for sqrt in CPU and CUDA
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/for_range.h"
namespace paddle {
namespace operators {
namespace scatter = paddle::operators::math::scatter;
template <typename T>
struct AdamFunctor {
T beta1_;
......@@ -134,8 +137,6 @@ struct SparseAdamFunctor {
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
// IMPORTANT:
// FIXME(typhoonzero): row id may be duplicate
moment1_out_[rows_[i] * row_numel_ + j] = mom1;
moment2_out_[rows_[i] * row_numel_ + j] = mom2;
param_out_[rows_[i] * row_numel_ + j] = p;
......@@ -191,10 +192,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto& grad =
Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
auto& grad_tensor = grad.value();
// merge duplicated rows if any.
scatter::MergeAdd<DeviceContext, T> merge_func;
auto grad_merge =
merge_func(ctx.template device_context<DeviceContext>(), grad);
auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>();
auto* rows = grad.rows().data();
auto row_numel = grad_tensor.numel() / grad.rows().size();
auto* rows = grad_merge.rows().data();
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
SparseAdamFunctor<T> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
......@@ -206,7 +211,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
grad.rows().size());
grad_merge.rows().size());
for_range(functor);
} else {
PADDLE_THROW("Variable type not supported by adam_op");
......
......@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/selected_rows_functor.h"
#include <set>
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
......@@ -193,27 +195,25 @@ size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
template <typename T>
struct MergeAdd<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* out) {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) {
framework::SelectedRows out;
auto input_rows = input.rows();
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto input_width = input.value().dims()[1];
// std::unique_ptr<framework::SelectedRows> out{
// new framework::SelectedRows()};
out->set_rows(merge_rows);
out->set_height(input.height());
out->mutable_value()->mutable_data<T>(
out.set_rows(merge_rows);
out.set_height(input.height());
out.mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
constant_functor(context, out->mutable_value(), 0.0);
constant_functor(context, out.mutable_value(), 0.0);
auto* out_data = out->mutable_value()->data<T>();
auto* out_data = out.mutable_value()->data<T>();
auto* input_data = input.value().data<T>();
for (size_t i = 0; i < input_rows.size(); i++) {
......@@ -222,6 +222,74 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
out_data[out_i * input_width + j] += input_data[i * input_width + j];
}
}
return out;
}
};
template struct MergeAdd<platform::CPUDeviceContext, float>;
template struct MergeAdd<platform::CPUDeviceContext, double>;
template struct MergeAdd<platform::CPUDeviceContext, int>;
template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
template <typename T>
struct UpdateToTensor<platform::CPUDeviceContext, T> {
framework::Tensor operator()(const platform::CPUDeviceContext& context,
const ScatterOps& op,
const framework::SelectedRows& input1,
framework::Tensor* input2) {
auto in1_height = input1.height();
auto in2_dims = input2->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
auto* in1_data = in1_value.data<T>();
auto* input2_data = input2->data<T>();
// FIXME(typhoonzero): use macro fix the below messy code.
switch (op) {
case ScatterOps::ASSIGN:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] =
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::ADD:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] +=
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::SUB:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] -=
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::SUBBY:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] =
in1_data[i * in1_row_numel + j] -
input2_data[in1_rows[i] * in1_row_numel + j];
break;
case ScatterOps::MUL:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] *=
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::DIV:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] /=
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::DIVBY:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] =
in1_data[i * in1_row_numel + j] /
input2_data[in1_rows[i] * in1_row_numel + j];
break;
}
}
};
......
......@@ -252,27 +252,26 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
template <typename T>
struct MergeAdd<platform::GPUDeviceContext, T> {
void operator()(const platform::GPUDeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* out) {
framework::SelectedRows operator()(const platform::GPUDeviceContext& context,
const framework::SelectedRows& input) {
framework::SelectedRows out;
auto input_rows = input.rows();
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto input_width = input.value().dims()[1];
// std::unique_ptr<framework::SelectedRows> out{
// new framework::SelectedRows()};
out->set_rows(merge_rows);
out->set_height(input.height());
out->mutable_value()->mutable_data<T>(
out.set_rows(merge_rows);
out.set_height(input.height());
out.mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
constant_functor(context, out->mutable_value(), 0.0);
constant_functor(context, out.mutable_value(), 0.0);
auto* out_data = out->mutable_value()->data<T>();
auto* out_data = out.mutable_value()->data<T>();
auto* input_data = input.value().data<T>();
const int block_size = 256;
......@@ -283,11 +282,96 @@ struct MergeAdd<platform::GPUDeviceContext, T> {
T, 256><<<grid1, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(input_data, input.rows().data(), out_data,
out->rows().data(), out->rows().size(),
out.rows().data(), out.rows().size(),
input_width);
return out;
}
};
template struct MergeAdd<platform::GPUDeviceContext, float>;
template struct MergeAdd<platform::GPUDeviceContext, double>;
template struct MergeAdd<platform::GPUDeviceContext, int>;
template struct MergeAdd<platform::GPUDeviceContext, int64_t>;
template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
const int64_t* rows, const ScatterOps& op,
T* tensor_out, int64_t row_numel) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
selected_rows += ty * row_numel;
tensor_out += rows[ty] * row_numel;
// FIXME(typhoonzero): use macro fix the below messy code.
switch (op) {
case ScatterOps::ASSIGN:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] = selected_rows[index];
}
break;
case ScatterOps::ADD:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] += selected_rows[index];
}
break;
case ScatterOps::SUB:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] -= selected_rows[index];
}
break;
case ScatterOps::SUBBY:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] = selected_rows[index] - tensor_out[index];
}
break;
case ScatterOps::MUL:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] *= selected_rows[index];
}
break;
case ScatterOps::DIV:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] /= selected_rows[index];
}
break;
case ScatterOps::DIVBY:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] = selected_rows[index] / tensor_out[index];
}
break;
}
}
template <typename T>
struct UpdateToTensor<platform::GPUDeviceContext, T> {
framework::Tensor operator()(const platform::GPUDeviceContext& context,
const ScatterOps& op,
const framework::SelectedRows& input1,
framework::Tensor* input2) {
// NOTE: Use SelectedRowsAddToTensor for better performance
// no additional MergeAdd called.
auto merged_in1 = MergeAdd()(context, input1);
auto in1_height = merged_in1.height();
auto in2_dims = input2->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
auto& in1_value = merged_in1.value();
auto& in1_rows = merged_in1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
auto* in1_data = in1_value.data<T>();
auto* input2_data = input2->data<T>();
dim3 threads(PADDLE_CUDA_NUM_THREADS, 1);
dim3 grid(1, in1_rows.size());
UpdateToTensorKernel<
T, PADDLE_CUDA_NUM_THREADS><<<grid, threads, 0, context.stream()>>>(
in1_data, in1_rows.data(), op, in2_data, in1_row_numel);
}
};
} // namespace scatter
} // namespace math
} // namespace operators
......
......@@ -16,6 +16,10 @@ limitations under the License. */
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"
#define INLINE_FOR2(sizei, sizej) \
for (int64_t i = 0; i < sizei; i++) \
for (int64_t j = 0; j < sizej; j++)
namespace paddle {
namespace operators {
namespace math {
......@@ -55,50 +59,76 @@ struct SelectedRowsAddToTensor {
namespace scatter {
// functors for manuplating SelectedRows data
template <typename DeviceContext, typename T>
struct MergeAdd {
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
void operator()(const DeviceContext& context,
const framework::SelectedRows& input,
framework::SelectedRows* out);
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input);
};
template <typename DeviceContext, typename T>
struct Add {
void operator()(const DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2,
framework::SelectedRows* out) {
out->set_rows(input1.rows());
out->set_height(input1.height());
out->mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2) {
framework::SelectedRows out;
out.set_rows(input1.rows());
out.set_height(input1.height());
out.mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
e_out.device(*context.eigen_device()) = e_in1 + e_in2;
return out;
}
};
template <typename DeviceContext, typename T>
struct Mul {
void operator()(const DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2,
framework::SelectedRows* out) {
out->set_rows(input1.rows());
out->set_height(input1.height());
out->mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
// multiply two SelectedRows
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2) {
framework::SelectedRows out;
out.set_rows(input1.rows());
out.set_height(input1.height());
out.mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
e_out.device(*context.eigen_device()) = e_in1 * e_in2;
return out;
}
// multiply scalar to SelectedRows
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input1,
const T input2) {
framework::SelectedRows out;
out.set_rows(input1.rows());
out.set_height(input1.height());
out.mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
e_out.device(*context.eigen_device()) = input2 * e_in1;
return out;
}
};
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
// out = seleted_rows_in / tensor
template <typename DeviceContext, typename T>
struct UpdateToTensor {
framework::Tensor operator()(const DeviceContext& context,
const ScatterOps& op,
const framework::SelectedRows& input1,
framework::Tensor* input2);
};
} // namespace scatter
} // namespace math
} // namespace operators
......
......@@ -285,7 +285,6 @@ class TestSparseAdamOp(unittest.TestCase):
j = 0
while j < self.row_numel:
pos = row_id * self.row_numel + j
print(actual[pos] - np_array[pos]) / actual[pos]
self.assertLess((actual[pos] - np_array[pos]) / actual[pos],
0.00001)
j += 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册