diff --git a/paddle/operators/adagrad_op.cc b/paddle/operators/adagrad_op.cc index 052c793a01907abdc7784d1290f43543ae81bdb1..c83318a272302a474c37ce86365201acf56b9cad 100644 --- a/paddle/operators/adagrad_op.cc +++ b/paddle/operators/adagrad_op.cc @@ -105,48 +105,18 @@ struct SparseAdagradFunctor { const framework::Tensor& learning_rate, T epsilon, framework::Tensor* moment, framework::Tensor* param) { // 1. g_m.rows = set(g.rows) - auto grad_rows = grad.rows(); - std::set row_set(grad_rows.begin(), grad_rows.end()); - std::vector merge_rows(row_set.begin(), row_set.end()); - auto grad_width = grad.value().dims()[1]; - std::unique_ptr grad_merge{ - new framework::SelectedRows()}; - grad_merge->set_rows(merge_rows); - grad_merge->set_height(grad.height()); - grad_merge->mutable_value()->mutable_data( - framework::make_ddim( - {static_cast(merge_rows.size()), grad_width}), - context.GetPlace()); - - math::SetConstant constant_functor; - constant_functor(context, grad_merge->mutable_value(), 0.0); - - auto* grad_merge_data = grad_merge->mutable_value()->data(); - auto* grad_data = grad.value().data(); - - for (size_t i = 0; i < grad_rows.size(); i++) { - size_t grad_merge_i = FindPos(merge_rows, grad_rows[i]); - for (int64_t j = 0; j < grad_width; j++) { - grad_merge_data[grad_merge_i * grad_width + j] += - grad_data[i * grad_width + j]; - } - } + math::scatter::MergeAdd merge_func; + auto grad_merge = merge_func(context, grad); + auto& merge_rows = grad_merge.rows(); + auto* grad_merge_data = grad_merge.mutable_value()->template data(); // 2. m += g_m * g_m - std::unique_ptr grad_square{ - new framework::SelectedRows()}; - grad_square->set_rows(grad_merge->rows()); - grad_square->set_height(grad_merge->height()); - grad_square->mutable_value()->mutable_data(grad_merge->value().dims(), - context.GetPlace()); - auto gs = - framework::EigenVector::Flatten(*(grad_square->mutable_value())); - auto gm = framework::EigenVector::Flatten(grad_merge->value()); - gs.device(*context.eigen_device()) = gm * gm; + math::scatter::Mul sqare_func; + auto grad_square = sqare_func(context, grad_merge, grad_merge); math::SelectedRowsAddToTensor functor; - functor(context, *grad_square, moment); + functor(context, grad_square, moment); // 3. update parameter auto* lr = learning_rate.data(); diff --git a/paddle/operators/adagrad_op.cu b/paddle/operators/adagrad_op.cu index 75bc7affd6c78beb783e01682b4538f2c259df26..4e579387924a5b0499f29609bc6b1322030a3c0d 100644 --- a/paddle/operators/adagrad_op.cu +++ b/paddle/operators/adagrad_op.cu @@ -78,62 +78,30 @@ struct SparseAdagradFunctor { const framework::Tensor& learning_rate, T epsilon, framework::Tensor* moment, framework::Tensor* param) { // 1. g_m.rows = set(g.rows) - auto grad_rows = grad.rows(); - std::set row_set(grad_rows.begin(), grad_rows.end()); - std::vector merge_rows(row_set.begin(), row_set.end()); - auto grad_width = grad.value().dims()[1]; - std::unique_ptr grad_merge{ - new framework::SelectedRows()}; - grad_merge->set_rows(merge_rows); - grad_merge->set_height(grad.height()); - grad_merge->mutable_value()->mutable_data( - framework::make_ddim( - {static_cast(merge_rows.size()), grad_width}), - context.GetPlace()); - - math::SetConstant constant_functor; - constant_functor(context, grad_merge->mutable_value(), 0.0); - - auto* grad_merge_data = grad_merge->mutable_value()->data(); - auto* grad_data = grad.value().data(); - - const int block_size = 256; - dim3 threads(block_size, 1); - dim3 grid1(1, grad_rows.size()); - - MergeGradKernel< - T, 256><<(context) - .stream()>>>(grad_data, grad.rows().data(), - grad_merge_data, grad_merge->rows().data(), - grad_merge->rows().size(), grad_width); - + math::scatter::MergeAdd merge_func; + auto grad_merge = merge_func(context, grad); + auto* grad_merge_data = grad_merge.mutable_value()->template data(); + auto& merge_rows = grad_merge.rows(); // 2. m += g_m * g_m - std::unique_ptr grad_square{ - new framework::SelectedRows()}; - grad_square->set_rows(grad_merge->rows()); - grad_square->set_height(grad_merge->height()); - grad_square->mutable_value()->mutable_data(grad_merge->value().dims(), - context.GetPlace()); - auto gs = - framework::EigenVector::Flatten(*(grad_square->mutable_value())); - auto gm = framework::EigenVector::Flatten(grad_merge->value()); - gs.device(*context.eigen_device()) = gm * gm; + math::scatter::Mul sqare_func; + auto grad_square = sqare_func(context, grad_merge, grad_merge); math::SelectedRowsAddToTensor functor; - functor(context, *grad_square, moment); + functor(context, grad_square, moment); // 3. update parameter auto* lr = learning_rate.data(); auto* param_data = param->data(); auto* moment_data = moment->data(); + const int block_size = 256; + dim3 threads(block_size, 1); dim3 grid2(1, merge_rows.size()); SparseAdagradFunctorKernel< T, 256><<(context) - .stream()>>>(grad_merge_data, grad_merge->rows().data(), + .stream()>>>(grad_merge_data, grad_merge.rows().data(), lr, param_data, moment_data, grad_width, epsilon); } diff --git a/paddle/operators/adam_op.h b/paddle/operators/adam_op.h index c4e2c8bb88ec9c74bd782570c10fb217178c8e48..9cc34bdded780e61e8700eb4fa4a295c84fb48bc 100644 --- a/paddle/operators/adam_op.h +++ b/paddle/operators/adam_op.h @@ -16,11 +16,14 @@ limitations under the License. */ #include // for sqrt in CPU and CUDA #include "paddle/framework/op_registry.h" #include "paddle/operators/detail/safe_ref.h" +#include "paddle/operators/math/selected_rows_functor.h" #include "paddle/platform/for_range.h" namespace paddle { namespace operators { +namespace scatter = paddle::operators::math::scatter; + template struct AdamFunctor { T beta1_; @@ -79,6 +82,69 @@ struct AdamFunctor { } }; +template +struct SparseAdamFunctor { + T beta1_; + T beta2_; + T epsilon_; + + const T* beta1_pow_; + const T* beta2_pow_; + const T* moment1_; + T* moment1_out_; + const T* moment2_; + T* moment2_out_; + const T* lr_; + const T* grad_; + const T* param_; + T* param_out_; + + const int64_t* rows_; + int64_t row_numel_; + + SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow, + const T* beta2_pow, const T* mom1, T* mom1_out, + const T* mom2, T* mom2_out, const T* lr, const T* grad, + const T* param, T* param_out, const int64_t* rows, + int64_t row_numel) + : beta1_(beta1), + beta2_(beta2), + epsilon_(epsilon), + beta1_pow_(beta1_pow), + beta2_pow_(beta2_pow), + moment1_(mom1), + moment1_out_(mom1_out), + moment2_(mom2), + moment2_out_(mom2_out), + lr_(lr), + grad_(grad), + param_(param), + param_out_(param_out), + rows_(rows), + row_numel_(row_numel) {} + + inline HOSTDEVICE void operator()(size_t i) const { + T beta1_pow = *beta1_pow_; + T beta2_pow = *beta2_pow_; + for (int64_t j = 0; j < row_numel_; ++j) { + T g = grad_[i * row_numel_ + j]; + T mom1 = moment1_[rows_[i] * row_numel_ + j]; + T mom2 = moment2_[rows_[i] * row_numel_ + j]; + T lr = *lr_; + T p = param_[rows_[i] * row_numel_ + j]; + + lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); + mom1 = beta1_ * mom1 + (1 - beta1_) * g; + mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; + p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); + + moment1_out_[rows_[i] * row_numel_ + j] = mom1; + moment2_out_[rows_[i] * row_numel_ + j] = mom2; + param_out_[rows_[i] * row_numel_ + j] = p; + } // for col id + } +}; + template class AdamOpKernel : public framework::OpKernel { public: @@ -90,7 +156,8 @@ class AdamOpKernel : public framework::OpKernel { T beta2 = static_cast(ctx.Attr("beta2")); T epsilon = static_cast(ctx.Attr("epsilon")); auto& param = Ref(ctx.Input("Param"), "Must set Param"); - auto& grad = Ref(ctx.Input("Grad"), "Must set Grad"); + // auto& grad = Ref(ctx.Input("Grad"), "Must set Grad"); + auto* grad_var = ctx.InputVar("Grad"); auto& mom1 = Ref(ctx.Input("Moment1"), "Must set Moment1"); auto& mom2 = Ref(ctx.Input("Moment2"), "Must set Moment2"); auto& lr = @@ -108,18 +175,48 @@ class AdamOpKernel : public framework::OpKernel { auto& mom2_out = Ref(ctx.Output("Moment2Out"), "Must set Moment1Out"); - AdamFunctor functor(beta1, beta2, epsilon, beta1_pow.template data(), - beta2_pow.template data(), - mom1.template data(), - mom1_out.template mutable_data(ctx.GetPlace()), - mom2.template data(), - mom2_out.template mutable_data(ctx.GetPlace()), - lr.template data(), grad.template data(), - param.template data(), - param_out.template mutable_data(ctx.GetPlace())); - platform::ForRange for_range( - static_cast(ctx.device_context()), param.numel()); - for_range(functor); + if (grad_var->IsType()) { + auto& grad = Ref(ctx.Input("Grad"), "Must set Grad"); + AdamFunctor functor( + beta1, beta2, epsilon, beta1_pow.template data(), + beta2_pow.template data(), mom1.template data(), + mom1_out.template mutable_data(ctx.GetPlace()), + mom2.template data(), + mom2_out.template mutable_data(ctx.GetPlace()), + lr.template data(), grad.template data(), + param.template data(), + param_out.template mutable_data(ctx.GetPlace())); + platform::ForRange for_range( + static_cast(ctx.device_context()), + param.numel()); + for_range(functor); + } else if (grad_var->IsType()) { + auto& grad = + Ref(ctx.Input("Grad"), "Must set Grad"); + // merge duplicated rows if any. + scatter::MergeAdd merge_func; + auto grad_merge = + merge_func(ctx.template device_context(), grad); + auto& grad_tensor = grad_merge.value(); + const T* grad_data = grad_tensor.template data(); + auto* rows = grad_merge.rows().data(); + auto row_numel = grad_tensor.numel() / grad_merge.rows().size(); + + SparseAdamFunctor functor( + beta1, beta2, epsilon, beta1_pow.template data(), + beta2_pow.template data(), mom1.template data(), + mom1_out.template mutable_data(ctx.GetPlace()), + mom2.template data(), + mom2_out.template mutable_data(ctx.GetPlace()), + lr.template data(), grad_data, param.template data(), + param_out.template mutable_data(ctx.GetPlace()), rows, row_numel); + platform::ForRange for_range( + static_cast(ctx.device_context()), + grad_merge.rows().size()); + for_range(functor); + } else { + PADDLE_THROW("Variable type not supported by adam_op"); + } } }; diff --git a/paddle/operators/math/selected_rows_functor.cc b/paddle/operators/math/selected_rows_functor.cc index ab758d1e7fd8ab361948b28e8cb735b9a742a339..8a1ebb58c26578f076bf243adfbd51d10c682b99 100644 --- a/paddle/operators/math/selected_rows_functor.cc +++ b/paddle/operators/math/selected_rows_functor.cc @@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/math/selected_rows_functor.h" +#include + #include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/selected_rows_functor.h" namespace paddle { namespace operators { @@ -179,6 +181,118 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +// This is a separated namespace for manipulate SelectedRows typed +// data. Like merge duplicated rows, adding two SelectedRows etc. +// +// Another group of functors is called "scatter updates", which means +// use SelectedRows to update a dense tensor with different Ops, like +// add or mul. +namespace scatter { + +size_t FindPos(const std::vector& rows, int64_t value) { + return std::find(rows.begin(), rows.end(), value) - rows.begin(); +} + +template +struct MergeAdd { + framework::SelectedRows operator()(const platform::CPUDeviceContext& context, + const framework::SelectedRows& input) { + framework::SelectedRows out; + auto input_rows = input.rows(); + std::set row_set(input_rows.begin(), input_rows.end()); + std::vector merge_rows(row_set.begin(), row_set.end()); + + auto input_width = input.value().dims()[1]; + out.set_rows(merge_rows); + out.set_height(input.height()); + out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), input_width}), + context.GetPlace()); + + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + auto* out_data = out.mutable_value()->data(); + auto* input_data = input.value().data(); + + for (size_t i = 0; i < input_rows.size(); i++) { + size_t out_i = FindPos(merge_rows, input_rows[i]); + for (int64_t j = 0; j < input_width; j++) { + out_data[out_i * input_width + j] += input_data[i * input_width + j]; + } + } + return out; + } +}; + +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; + +template +struct UpdateToTensor { + void operator()(const platform::CPUDeviceContext& context, + const ScatterOps& op, const framework::SelectedRows& input1, + framework::Tensor* input2) { + auto in1_height = input1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height); + + auto* in1_data = in1_value.data(); + auto* input2_data = input2->data(); + + // FIXME(typhoonzero): use macro fix the below messy code. + switch (op) { + case ScatterOps::ASSIGN: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] = + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::ADD: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] += + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::SUB: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] -= + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::SUBBY: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] = + in1_data[i * in1_row_numel + j] - + input2_data[in1_rows[i] * in1_row_numel + j]; + break; + case ScatterOps::MUL: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] *= + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::DIV: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] /= + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::DIVBY: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] = + in1_data[i * in1_row_numel + j] / + input2_data[in1_rows[i] * in1_row_numel + j]; + break; + } + } +}; + +} // namespace scatter } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/selected_rows_functor.cu b/paddle/operators/math/selected_rows_functor.cu index 9fddd97a36f7fdb6628d6eeb192cb216fdae3e5b..0ee456f9bc61436bd0f2f8ef20dd1654e7e56d56 100644 --- a/paddle/operators/math/selected_rows_functor.cu +++ b/paddle/operators/math/selected_rows_functor.cu @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include + #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/selected_rows_functor.h" #include "paddle/platform/cuda_helper.h" @@ -222,6 +224,157 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; + +namespace scatter { + +template +__global__ void MergeAddKernel(const T* input, const int64_t* input_rows, + T* out, const int64_t* out_rows, + size_t out_rows_size, int64_t row_numel) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + __shared__ size_t out_idx; + + if (tid == 0) { + for (size_t i = 0; i < out_rows_size; i++) { + if (input_rows[ty] == out_rows[i]) { + out_idx = i; + } + } + } + + __syncthreads(); + + input += ty * row_numel; + out += out_idx * row_numel; + for (int index = tid; index < row_numel; index += block_size) { + paddle::platform::CudaAtomicAdd(out + index, input[index]); + } +} + +template +struct MergeAdd { + framework::SelectedRows operator()(const platform::CUDADeviceContext& context, + const framework::SelectedRows& input) { + framework::SelectedRows out; + auto input_rows = input.rows(); + std::set row_set(input_rows.begin(), input_rows.end()); + std::vector merge_rows(row_set.begin(), row_set.end()); + + auto input_width = input.value().dims()[1]; + + out.set_rows(merge_rows); + out.set_height(input.height()); + out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), input_width}), + context.GetPlace()); + + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + auto* out_data = out.mutable_value()->data(); + auto* input_data = input.value().data(); + + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid1(1, input_rows.size()); + + MergeAddKernel< + T, 256><<(context) + .stream()>>>(input_data, input.rows().data(), out_data, + out.rows().data(), out.rows().size(), + input_width); + return out; + } +}; + +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; + +template +__global__ void UpdateToTensorKernel(const T* selected_rows, + const int64_t* rows, const ScatterOps& op, + T* tensor_out, int64_t row_numel) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + + selected_rows += ty * row_numel; + tensor_out += rows[ty] * row_numel; + // FIXME(typhoonzero): use macro fix the below messy code. + switch (op) { + case ScatterOps::ASSIGN: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] = selected_rows[index]; + } + break; + case ScatterOps::ADD: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] += selected_rows[index]; + } + break; + case ScatterOps::SUB: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] -= selected_rows[index]; + } + break; + case ScatterOps::SUBBY: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] = selected_rows[index] - tensor_out[index]; + } + break; + case ScatterOps::MUL: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] *= selected_rows[index]; + } + break; + case ScatterOps::DIV: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] /= selected_rows[index]; + } + break; + case ScatterOps::DIVBY: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] = selected_rows[index] / tensor_out[index]; + } + break; + } +} + +template +struct UpdateToTensor { + void operator()(const platform::CUDADeviceContext& context, + const ScatterOps& op, const framework::SelectedRows& input1, + framework::Tensor* input2) { + // NOTE: Use SelectedRowsAddToTensor for better performance + // no additional MergeAdd called. + MergeAdd merge_func; + auto merged_in1 = merge_func(context, input1); + + auto in1_height = merged_in1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); + + auto& in1_value = merged_in1.value(); + auto& in1_rows = merged_in1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height); + + auto* in1_data = in1_value.template data(); + auto* in2_data = input2->data(); + + dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1); + dim3 grid(1, in1_rows.size()); + UpdateToTensorKernel<<< + grid, threads, 0, context.stream()>>>(in1_data, in1_rows.data(), op, + in2_data, in1_row_numel); + } +}; +} // namespace scatter } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/selected_rows_functor.h b/paddle/operators/math/selected_rows_functor.h index 1149075abf16547a120ac8928c45b4972409fc72..09d4631905f90f78772368ad71b11826877bdc34 100644 --- a/paddle/operators/math/selected_rows_functor.h +++ b/paddle/operators/math/selected_rows_functor.h @@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/selected_rows.h" #include "paddle/platform/device_context.h" +#define INLINE_FOR2(sizei, sizej) \ + for (int64_t i = 0; i < sizei; i++) \ + for (int64_t j = 0; j < sizej; j++) + namespace paddle { namespace operators { namespace math { @@ -52,6 +57,78 @@ struct SelectedRowsAddToTensor { framework::Tensor* input2); }; +namespace scatter { +// functors for manuplating SelectedRows data +template +struct MergeAdd { + // unary functor, merge by adding duplicated rows in + // the input SelectedRows object. + framework::SelectedRows operator()(const DeviceContext& context, + const framework::SelectedRows& input); +}; + +template +struct Add { + framework::SelectedRows operator()(const DeviceContext& context, + const framework::SelectedRows& input1, + const framework::SelectedRows& input2) { + framework::SelectedRows out; + out.set_rows(input1.rows()); + out.set_height(input1.height()); + out.mutable_value()->mutable_data(input1.value().dims(), + context.GetPlace()); + auto e_out = framework::EigenVector::Flatten(*(out.mutable_value())); + auto e_in1 = framework::EigenVector::Flatten(input1.value()); + auto e_in2 = framework::EigenVector::Flatten(input2.value()); + e_out.device(*context.eigen_device()) = e_in1 + e_in2; + return out; + } +}; + +template +struct Mul { + // multiply two SelectedRows + framework::SelectedRows operator()(const DeviceContext& context, + const framework::SelectedRows& input1, + const framework::SelectedRows& input2) { + framework::SelectedRows out; + out.set_rows(input1.rows()); + out.set_height(input1.height()); + out.mutable_value()->mutable_data(input1.value().dims(), + context.GetPlace()); + auto e_out = framework::EigenVector::Flatten(*(out.mutable_value())); + auto e_in1 = framework::EigenVector::Flatten(input1.value()); + auto e_in2 = framework::EigenVector::Flatten(input2.value()); + e_out.device(*context.eigen_device()) = e_in1 * e_in2; + return out; + } + // multiply scalar to SelectedRows + framework::SelectedRows operator()(const DeviceContext& context, + const framework::SelectedRows& input1, + const T input2) { + framework::SelectedRows out; + out.set_rows(input1.rows()); + out.set_height(input1.height()); + out.mutable_value()->mutable_data(input1.value().dims(), + context.GetPlace()); + auto e_out = framework::EigenVector::Flatten(*(out.mutable_value())); + auto e_in1 = framework::EigenVector::Flatten(input1.value()); + e_out.device(*context.eigen_device()) = input2 * e_in1; + return out; + } +}; + +enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY }; + +// out = seleted_rows_in / tensor +template +struct UpdateToTensor { + void operator()(const DeviceContext& context, const ScatterOps& op, + const framework::SelectedRows& input1, + framework::Tensor* input2); +}; + +} // namespace scatter } // namespace math } // namespace operators } // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_adam_op.py b/python/paddle/v2/fluid/tests/test_adam_op.py index a0d6655d4cbcff8ed3d55df0f4e68fc6591fbb11..7dbc2fa0858a68c5da9e8d48dcb187494357e940 100644 --- a/python/paddle/v2/fluid/tests/test_adam_op.py +++ b/python/paddle/v2/fluid/tests/test_adam_op.py @@ -1,6 +1,8 @@ import unittest import numpy as np from op_test import OpTest +from paddle.v2.fluid import core +from paddle.v2.fluid.op import Operator class TestAdamOp1(OpTest): @@ -176,5 +178,124 @@ def adam_step(inputs, attributes): return param_out, moment1_out, moment2_out +def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad): + ''' + Simulate one step of the adam optimizer + :param inputs: dict of inputs + :param attributes: dict of attributes + :return tuple: tuple of output param, moment1, moment2, + beta1 power accumulator and beta2 power accumulator + ''' + param = inputs['Param'] + # grad = inputs['Grad'] + moment1 = inputs['Moment1'] + moment2 = inputs['Moment2'] + lr = inputs['LearningRate'] + beta1_pow = inputs['Beta1Pow'] + beta2_pow = inputs['Beta2Pow'] + + beta1 = attributes['beta1'] + beta2 = attributes['beta2'] + epsilon = attributes['epsilon'] + + moment1_out = np.zeros(shape=[height, row_numel]) + moment2_out = np.zeros(shape=[height, row_numel]) + param_out = np.zeros(shape=[height, row_numel]) + + for idx, row_id in enumerate(rows): + moment1_out[row_id] = beta1 * moment1[row_id] + (1 - beta1 + ) * np_grad[idx] + moment2_out[row_id] = beta2 * moment2[row_id] + ( + 1 - beta2) * np.square(np_grad[idx]) + lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow) + param_out[row_id] = param[row_id] - lr_t * (moment1_out[row_id] / ( + np.sqrt(moment2_out[row_id]) + epsilon)) + return param_out, moment1_out, moment2_out + + +class TestSparseAdamOp(unittest.TestCase): + def setup(self, scope, place): + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + + height = 10 + rows = [0, 4, 7] + self.rows = rows + row_numel = 12 + self.row_numel = row_numel + self.dense_inputs = { + "Param": np.full((height, row_numel), 5.0).astype("float32"), + "Moment1": np.full((height, row_numel), 5.0).astype("float32"), + "Moment2": np.full((height, row_numel), 5.0).astype("float32"), + 'Beta1Pow': np.array([beta1**10]).astype("float32"), + 'Beta2Pow': np.array([beta2**10]).astype("float32"), + "LearningRate": np.full((1), 2.0).astype("float32") + } + self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} + + grad_selected_rows = scope.var('Grad').get_selected_rows() + grad_selected_rows.set_height(height) + grad_selected_rows.set_rows(rows) + np_array = np.ones((len(rows), row_numel)).astype("float32") + np_array[0, 0] = 2.0 + np_array[2, 8] = 4.0 + + grad_tensor = grad_selected_rows.get_tensor() + grad_tensor.set(np_array, place) + + self.sparse_inputs = ["Grad"] + + param_out, mom1, mom2 = adam_step_sparse( + self.dense_inputs, self.attrs, height, rows, row_numel, np_array) + self.outputs = { + "ParamOut": param_out, + "Moment1Out": mom1, + "Moment2Out": mom2 + } + + def check_with_place(self, place): + scope = core.Scope() + self.setup(scope, place) + + op_args = dict() + for key, np_array in self.dense_inputs.iteritems(): + var = scope.var(key).get_tensor() + var.set(np_array, place) + op_args[key] = key + for s in self.sparse_inputs: + op_args[s] = s + for s in self.outputs: + var = scope.var(s).get_tensor() + var.set(self.outputs[s], place) + op_args[s] = s + for k in self.attrs: + op_args[k] = self.attrs[k] + + # create and run sgd operator + adam_op = Operator("adam", **op_args) + adam_op.run(scope, place) + + for key, np_array in self.outputs.iteritems(): + out_var = scope.var(key).get_tensor() + actual = np.array(out_var) + actual = actual.reshape([actual.size]) + np_array = np_array.reshape([np_array.size]) + for idx, row_id in enumerate(self.rows): + j = 0 + while j < self.row_numel: + pos = row_id * self.row_numel + j + self.assertLess((actual[pos] - np_array[pos]) / actual[pos], + 0.00001) + j += 1 + + def test_sparse_sgd(self): + places = [core.CPUPlace()] + if core.is_compile_gpu(): + places.append(core.CUDAPlace(0)) + for place in places: + self.check_with_place(place) + + if __name__ == "__main__": unittest.main()