diff --git a/paddle/operators/adagrad_op.cc b/paddle/operators/adagrad_op.cc
index 052c793a01907abdc7784d1290f43543ae81bdb1..c83318a272302a474c37ce86365201acf56b9cad 100644
--- a/paddle/operators/adagrad_op.cc
+++ b/paddle/operators/adagrad_op.cc
@@ -105,48 +105,18 @@ struct SparseAdagradFunctor<platform::CPUDeviceContext, T> {
                   const framework::Tensor& learning_rate, T epsilon,
                   framework::Tensor* moment, framework::Tensor* param) {
     // 1. g_m.rows = set(g.rows)
-    auto grad_rows = grad.rows();
-    std::set<int64_t> row_set(grad_rows.begin(), grad_rows.end());
-    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
-
-    auto grad_width = grad.value().dims()[1];
-    std::unique_ptr<framework::SelectedRows> grad_merge{
-        new framework::SelectedRows()};
-    grad_merge->set_rows(merge_rows);
-    grad_merge->set_height(grad.height());
-    grad_merge->mutable_value()->mutable_data<T>(
-        framework::make_ddim(
-            {static_cast<int64_t>(merge_rows.size()), grad_width}),
-        context.GetPlace());
-
-    math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
-    constant_functor(context, grad_merge->mutable_value(), 0.0);
-
-    auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
-    auto* grad_data = grad.value().data<T>();
-
-    for (size_t i = 0; i < grad_rows.size(); i++) {
-      size_t grad_merge_i = FindPos(merge_rows, grad_rows[i]);
-      for (int64_t j = 0; j < grad_width; j++) {
-        grad_merge_data[grad_merge_i * grad_width + j] +=
-            grad_data[i * grad_width + j];
-      }
-    }
+    math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
+    auto grad_merge = merge_func(context, grad);
+    auto& merge_rows = grad_merge.rows();
+    auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
 
     // 2. m += g_m * g_m
-    std::unique_ptr<framework::SelectedRows> grad_square{
-        new framework::SelectedRows()};
-    grad_square->set_rows(grad_merge->rows());
-    grad_square->set_height(grad_merge->height());
-    grad_square->mutable_value()->mutable_data<T>(grad_merge->value().dims(),
-                                                  context.GetPlace());
-    auto gs =
-        framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
-    auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
-    gs.device(*context.eigen_device()) = gm * gm;
+    math::scatter::Mul<platform::CPUDeviceContext, T> square_func;
+    auto grad_square = square_func(context, grad_merge, grad_merge);
 
     math::SelectedRowsAddToTensor<platform::CPUDeviceContext, T> functor;
-    functor(context, *grad_square, moment);
+    functor(context, grad_square, moment);
 
     // 3. update parameter
     auto* lr = learning_rate.data<T>();
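For readers unfamiliar with the merge step that now replaces the hand-rolled loop above: scatter::MergeAdd collapses a SelectedRows object so that every row id appears exactly once, summing the values of duplicated rows. A minimal standalone sketch of that semantics (toy types and names, not Paddle's buffer-level implementation shown in the removed lines):

#include <cstdint>
#include <map>
#include <vector>

// Toy stand-in for SelectedRows: parallel arrays of row ids and dense rows.
struct ToySelectedRows {
  std::vector<int64_t> rows;
  std::vector<std::vector<double>> values;  // values[i] belongs to rows[i]
};

// Merge duplicated row ids by adding their values, as scatter::MergeAdd does.
ToySelectedRows MergeAddToy(const ToySelectedRows& in) {
  std::map<int64_t, std::vector<double>> merged;  // sorted, unique row ids
  for (size_t i = 0; i < in.rows.size(); ++i) {
    auto& acc = merged[in.rows[i]];
    if (acc.empty()) acc.assign(in.values[i].size(), 0.0);
    for (size_t j = 0; j < in.values[i].size(); ++j) {
      acc[j] += in.values[i][j];
    }
  }
  ToySelectedRows out;
  for (const auto& kv : merged) {
    out.rows.push_back(kv.first);
    out.values.push_back(kv.second);
  }
  return out;
}

For example, rows {3, 1, 3} with values {{1,1},{2,2},{4,4}} merge into rows {1, 3} with values {{2,2},{5,5}}.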
diff --git a/paddle/operators/adagrad_op.cu b/paddle/operators/adagrad_op.cu
index 585b2d92894af65b8ed15a596f0377fdcf564cfa..86b3dd860d9a647f6fc0ec1d699f08158772b306 100644
--- a/paddle/operators/adagrad_op.cu
+++ b/paddle/operators/adagrad_op.cu
@@ -78,51 +78,17 @@ struct SparseAdagradFunctor<platform::GPUDeviceContext, T> {
                   const framework::Tensor& learning_rate, T epsilon,
                   framework::Tensor* moment, framework::Tensor* param) {
     // 1. g_m.rows = set(g.rows)
-    auto grad_rows = grad.rows();
-    std::set<int64_t> row_set(grad_rows.begin(), grad_rows.end());
-    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
-
-    auto grad_width = grad.value().dims()[1];
-    std::unique_ptr<framework::SelectedRows> grad_merge{
-        new framework::SelectedRows()};
-    grad_merge->set_rows(merge_rows);
-    grad_merge->set_height(grad.height());
-    grad_merge->mutable_value()->mutable_data<T>(
-        framework::make_ddim(
-            {static_cast<int64_t>(merge_rows.size()), grad_width}),
-        context.GetPlace());
-
-    math::SetConstant<platform::GPUDeviceContext, T> constant_functor;
-    constant_functor(context, grad_merge->mutable_value(), 0.0);
-
-    auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
-    auto* grad_data = grad.value().data<T>();
-
-    const int block_size = 256;
-    dim3 threads(block_size, 1);
-    dim3 grid1(1, grad_rows.size());
-
-    MergeGradKernel<
-        T, 256><<<grid1, threads, 0,
-                  reinterpret_cast<const platform::GPUDeviceContext&>(context)
-                      .stream()>>>(grad_data, grad.rows().data(),
-                                   grad_merge_data, grad_merge->rows().data(),
-                                   grad_merge->rows().size(), grad_width);
-
+    math::scatter::MergeAdd<platform::GPUDeviceContext, T> merge_func;
+    auto grad_merge = merge_func(context, grad);
+    auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
+    auto& merge_rows = grad_merge.rows();
     // 2. m += g_m * g_m
-    std::unique_ptr<framework::SelectedRows> grad_square{
-        new framework::SelectedRows()};
-    grad_square->set_rows(grad_merge->rows());
-    grad_square->set_height(grad_merge->height());
-    grad_square->mutable_value()->mutable_data<T>(grad_merge->value().dims(),
-                                                  context.GetPlace());
-    auto gs =
-        framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
-    auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
-    gs.device(*context.eigen_device()) = gm * gm;
+    math::scatter::Mul<platform::GPUDeviceContext, T> square_func;
+    auto grad_square = square_func(context, grad_merge, grad_merge);
 
     math::SelectedRowsAddToTensor<platform::GPUDeviceContext, T> functor;
-    functor(context, *grad_square, moment);
+    functor(context, grad_square, moment);
 
     // 3. update parameter
     auto* lr = learning_rate.data<T>();
diff --git a/paddle/operators/adam_op.h b/paddle/operators/adam_op.h
index 5facd0112f14baedec46ce2d72c9456c1dc0dfe2..3c4148ccc0a7d5e2a2ef095cf8f639db09be0fc2 100644
--- a/paddle/operators/adam_op.h
+++ b/paddle/operators/adam_op.h
@@ -16,11 +16,14 @@ limitations under the License. */
 #include <math.h>  // for sqrt in CPU and CUDA
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/detail/safe_ref.h"
+#include "paddle/operators/math/selected_rows_functor.h"
 #include "paddle/platform/for_range.h"
 
 namespace paddle {
 namespace operators {
 
+namespace scatter = paddle::operators::math::scatter;
+
 template <typename T>
 struct AdamFunctor {
   T beta1_;
@@ -134,8 +137,6 @@ struct SparseAdamFunctor {
     mom1 = beta1_ * mom1 + (1 - beta1_) * g;
     mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
     p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
-    // IMPORTANT:
-    // FIXME(typhoonzero): row id may be duplicate
     moment1_out_[rows_[i] * row_numel_ + j] = mom1;
     moment2_out_[rows_[i] * row_numel_ + j] = mom2;
     param_out_[rows_[i] * row_numel_ + j] = p;
@@ -191,10 +192,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
     } else if (grad_var->IsType<framework::SelectedRows>()) {
       auto& grad =
           Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
-      auto& grad_tensor = grad.value();
+      // merge duplicated rows if any.
+      scatter::MergeAdd<DeviceContext, T> merge_func;
+      auto grad_merge =
+          merge_func(ctx.template device_context<DeviceContext>(), grad);
+      auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
-      auto* rows = grad.rows().data();
-      auto row_numel = grad_tensor.numel() / grad.rows().size();
+      auto* rows = grad_merge.rows().data();
+      auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
 
       SparseAdamFunctor<T> functor(
           beta1, beta2, epsilon, beta1_pow.template data<T>(),
@@ -206,7 +211,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
           param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel);
       platform::ForRange<DeviceContext> for_range(
           static_cast<const DeviceContext&>(ctx.device_context()),
-          grad.rows().size());
+          grad_merge.rows().size());
       for_range(functor);
     } else {
       PADDLE_THROW("Variable type not supported by adam_op");
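Why the merge is required here rather than merely nice to have: the Adam update is nonlinear in the gradient (it squares g for the second moment), so visiting a duplicated row id twice is not equivalent to visiting it once with the summed gradient. Worse, SparseAdamFunctor reads the stale input moments and runs under ForRange, so duplicated ids would race and the last write would win — exactly what the removed FIXME warned about. A tiny standalone calculation (hypothetical values) makes the discrepancy concrete:

#include <cstdio>

int main() {
  const double beta2 = 0.999, g = 1.0, mom2_in = 0.0;
  // Duplicated row id: both visits read the same stale input moment,
  // so the second write simply overwrites the first.
  double unmerged = beta2 * mom2_in + (1 - beta2) * g * g;            // 0.001
  // Merged first: a single visit with the summed gradient 2g.
  double merged = beta2 * mom2_in + (1 - beta2) * (2 * g) * (2 * g);  // 0.004
  std::printf("unmerged: %g  merged: %g\n", unmerged, merged);
  return 0;
}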
diff --git a/paddle/operators/math/selected_rows_functor.cc b/paddle/operators/math/selected_rows_functor.cc
index 21418ba4b0201e50302edea66d016e9464891f36..c9f3c10c61700d7faf91085c3f3d0564662873ae 100644
--- a/paddle/operators/math/selected_rows_functor.cc
+++ b/paddle/operators/math/selected_rows_functor.cc
@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/math/selected_rows_functor.h"
+#include <set>
+
 #include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/selected_rows_functor.h"
 
 namespace paddle {
 namespace operators {
@@ -193,27 +195,25 @@ size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
 
 template <typename T>
 struct MergeAdd<platform::CPUDeviceContext, T> {
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::SelectedRows& input,
-                  framework::SelectedRows* out) {
+  framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
+                                     const framework::SelectedRows& input) {
+    framework::SelectedRows out;
     auto input_rows = input.rows();
     std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
     std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
 
     auto input_width = input.value().dims()[1];
-    // std::unique_ptr<framework::SelectedRows> out{
-    //     new framework::SelectedRows()};
-    out->set_rows(merge_rows);
-    out->set_height(input.height());
-    out->mutable_value()->mutable_data<T>(
+    out.set_rows(merge_rows);
+    out.set_height(input.height());
+    out.mutable_value()->mutable_data<T>(
         framework::make_ddim(
             {static_cast<int64_t>(merge_rows.size()), input_width}),
         context.GetPlace());
 
     math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
-    constant_functor(context, out->mutable_value(), 0.0);
+    constant_functor(context, out.mutable_value(), 0.0);
 
-    auto* out_data = out->mutable_value()->data<T>();
+    auto* out_data = out.mutable_value()->data<T>();
     auto* input_data = input.value().data<T>();
 
     for (size_t i = 0; i < input_rows.size(); i++) {
@@ -222,6 +222,74 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
         out_data[out_i * input_width + j] += input_data[i * input_width + j];
       }
     }
+    return out;
   }
 };
 
+template struct MergeAdd<platform::CPUDeviceContext, float>;
+template struct MergeAdd<platform::CPUDeviceContext, double>;
+template struct MergeAdd<platform::CPUDeviceContext, int>;
+template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
+
+template <typename T>
+struct UpdateToTensor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& context,
+                  const ScatterOps& op, const framework::SelectedRows& input1,
+                  framework::Tensor* input2) {
+    auto in1_height = input1.height();
+    auto in2_dims = input2->dims();
+    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
+
+    auto& in1_value = input1.value();
+    auto& in1_rows = input1.rows();
+
+    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
+    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
+
+    auto* in1_data = in1_value.data<T>();
+    auto* input2_data = input2->data<T>();
+
+    // FIXME(typhoonzero): use a macro to fix the below messy code.
+    switch (op) {
+      case ScatterOps::ASSIGN:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] =
+            in1_data[i * in1_row_numel + j];
+        break;
+      case ScatterOps::ADD:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] +=
+            in1_data[i * in1_row_numel + j];
+        break;
+      case ScatterOps::SUB:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] -=
+            in1_data[i * in1_row_numel + j];
+        break;
+      case ScatterOps::SUBBY:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] =
+            in1_data[i * in1_row_numel + j] -
+            input2_data[in1_rows[i] * in1_row_numel + j];
+        break;
+      case ScatterOps::MUL:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] *=
+            in1_data[i * in1_row_numel + j];
+        break;
+      case ScatterOps::DIV:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] /=
+            in1_data[i * in1_row_numel + j];
+        break;
+      case ScatterOps::DIVBY:
+        INLINE_FOR2(in1_rows.size(), in1_row_numel)
+        input2_data[in1_rows[i] * in1_row_numel + j] =
+            in1_data[i * in1_row_numel + j] /
+            input2_data[in1_rows[i] * in1_row_numel + j];
+        break;
+    }
+  }
+};
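Note the operand order in the *BY variants above: SUB computes tensor - selected_rows, while SUBBY computes selected_rows - tensor, and likewise DIV versus DIVBY. A short standalone illustration with hypothetical values s = 6 (selected-rows entry) and t = 2 (tensor entry):

#include <cstdio>

int main() {
  const double s = 6.0, t = 2.0;
  std::printf("SUB:   %g\n", t - s);  // -4: tensor -= selected_rows
  std::printf("SUBBY: %g\n", s - t);  //  4: tensor = selected_rows - tensor
  std::printf("DIV:   %g\n", t / s);  //  1/3: tensor /= selected_rows
  std::printf("DIVBY: %g\n", s / t);  //  3: tensor = selected_rows / tensor
  return 0;
}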
diff --git a/paddle/operators/math/selected_rows_functor.cu b/paddle/operators/math/selected_rows_functor.cu
index b2c0fe7bc3da725a9175930657bb4fc8f6da3764..48413403db513056299fc578a64ca7a63aea12c3 100644
--- a/paddle/operators/math/selected_rows_functor.cu
+++ b/paddle/operators/math/selected_rows_functor.cu
@@ -252,27 +252,26 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
 
 template <typename T>
 struct MergeAdd<platform::GPUDeviceContext, T> {
-  void operator()(const platform::GPUDeviceContext& context,
-                  const framework::SelectedRows& input,
-                  framework::SelectedRows* out) {
+  framework::SelectedRows operator()(const platform::GPUDeviceContext& context,
+                                     const framework::SelectedRows& input) {
+    framework::SelectedRows out;
     auto input_rows = input.rows();
     std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
     std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
 
     auto input_width = input.value().dims()[1];
-    // std::unique_ptr<framework::SelectedRows> out{
-    //     new framework::SelectedRows()};
-    out->set_rows(merge_rows);
-    out->set_height(input.height());
-    out->mutable_value()->mutable_data<T>(
+
+    out.set_rows(merge_rows);
+    out.set_height(input.height());
+    out.mutable_value()->mutable_data<T>(
         framework::make_ddim(
             {static_cast<int64_t>(merge_rows.size()), input_width}),
         context.GetPlace());
 
     math::SetConstant<platform::GPUDeviceContext, T> constant_functor;
-    constant_functor(context, out->mutable_value(), 0.0);
+    constant_functor(context, out.mutable_value(), 0.0);
 
-    auto* out_data = out->mutable_value()->data<T>();
+    auto* out_data = out.mutable_value()->data<T>();
     auto* input_data = input.value().data<T>();
 
     const int block_size = 256;
@@ -283,11 +282,96 @@ struct MergeAdd<platform::GPUDeviceContext, T> {
         T, 256><<<grid1, threads, 0,
                   reinterpret_cast<const platform::GPUDeviceContext&>(context)
                       .stream()>>>(input_data, input.rows().data(), out_data,
-                                   out->rows().data(), out->rows().size(),
+                                   out.rows().data(), out.rows().size(),
                                    input_width);
+    return out;
   }
 };
 
+template struct MergeAdd<platform::GPUDeviceContext, float>;
+template struct MergeAdd<platform::GPUDeviceContext, double>;
+template struct MergeAdd<platform::GPUDeviceContext, int>;
+template struct MergeAdd<platform::GPUDeviceContext, int64_t>;
+
+template <typename T, int block_size>
+__global__ void UpdateToTensorKernel(const T* selected_rows,
+                                     const int64_t* rows, const ScatterOps& op,
+                                     T* tensor_out, int64_t row_numel) {
+  const int ty = blockIdx.y;
+  int tid = threadIdx.x;
+
+  selected_rows += ty * row_numel;
+  tensor_out += rows[ty] * row_numel;
+  // FIXME(typhoonzero): use a macro to fix the below messy code.
+  switch (op) {
+    case ScatterOps::ASSIGN:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] = selected_rows[index];
+      }
+      break;
+    case ScatterOps::ADD:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] += selected_rows[index];
+      }
+      break;
+    case ScatterOps::SUB:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] -= selected_rows[index];
+      }
+      break;
+    case ScatterOps::SUBBY:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] = selected_rows[index] - tensor_out[index];
+      }
+      break;
+    case ScatterOps::MUL:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] *= selected_rows[index];
+      }
+      break;
+    case ScatterOps::DIV:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] /= selected_rows[index];
+      }
+      break;
+    case ScatterOps::DIVBY:
+      for (int index = tid; index < row_numel; index += block_size) {
+        tensor_out[index] = selected_rows[index] / tensor_out[index];
+      }
+      break;
+  }
+}
+
+template <typename T>
+struct UpdateToTensor<platform::GPUDeviceContext, T> {
+  void operator()(const platform::GPUDeviceContext& context,
+                  const ScatterOps& op, const framework::SelectedRows& input1,
+                  framework::Tensor* input2) {
+    // NOTE: for plain addition, use SelectedRowsAddToTensor instead; it is
+    // faster because no additional MergeAdd pass is needed.
+    auto merged_in1 =
+        MergeAdd<platform::GPUDeviceContext, T>()(context, input1);
+
+    auto in1_height = merged_in1.height();
+    auto in2_dims = input2->dims();
+    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
+
+    auto& in1_value = merged_in1.value();
+    auto& in1_rows = merged_in1.rows();
+
+    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
+    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
+
+    auto* in1_data = in1_value.data<T>();
+    auto* input2_data = input2->data<T>();
+
+    dim3 threads(PADDLE_CUDA_NUM_THREADS, 1);
+    dim3 grid(1, in1_rows.size());
+    UpdateToTensorKernel<
+        T, PADDLE_CUDA_NUM_THREADS><<<grid, threads, 0, context.stream()>>>(
+        in1_data, in1_rows.data(), op, input2_data, in1_row_numel);
+  }
+};
 }  // namespace scatter
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/selected_rows_functor.h b/paddle/operators/math/selected_rows_functor.h
index eecd5e5362bbb398933a3ca287a2c25997f0bcfd..d4bef72980f8374c1519c1774b72b67d67397a2c 100644
--- a/paddle/operators/math/selected_rows_functor.h
+++ b/paddle/operators/math/selected_rows_functor.h
@@ -16,6 +16,10 @@ limitations under the License. */
 #include "paddle/framework/selected_rows.h"
 #include "paddle/platform/device_context.h"
 
+#define INLINE_FOR2(sizei, sizej)     \
+  for (int64_t i = 0; i < sizei; i++) \
+    for (int64_t j = 0; j < sizej; j++)
+
 namespace paddle {
 namespace operators {
 namespace math {
@@ -55,50 +59,76 @@ struct SelectedRowsAddToTensor {
 
 namespace scatter {
 // functors for manipulating SelectedRows data
-
 template <typename DeviceContext, typename T>
 struct MergeAdd {
   // unary functor, merge by adding duplicated rows in
   // the input SelectedRows object.
-  void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& input,
-                  framework::SelectedRows* out);
+  framework::SelectedRows operator()(const DeviceContext& context,
+                                     const framework::SelectedRows& input);
 };
 
 template <typename DeviceContext, typename T>
 struct Add {
-  void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& input1,
-                  const framework::SelectedRows& input2,
-                  framework::SelectedRows* out) {
-    out->set_rows(input1.rows());
-    out->set_height(input1.height());
-    out->mutable_value()->mutable_data<T>(input1.value().dims(),
-                                          context.GetPlace());
-    auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
+  framework::SelectedRows operator()(const DeviceContext& context,
+                                     const framework::SelectedRows& input1,
+                                     const framework::SelectedRows& input2) {
+    framework::SelectedRows out;
+    out.set_rows(input1.rows());
+    out.set_height(input1.height());
+    out.mutable_value()->mutable_data<T>(input1.value().dims(),
+                                         context.GetPlace());
+    auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
     auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
     auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
     e_out.device(*context.eigen_device()) = e_in1 + e_in2;
+    return out;
   }
 };
 
 template <typename DeviceContext, typename T>
 struct Mul {
-  void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& input1,
-                  const framework::SelectedRows& input2,
-                  framework::SelectedRows* out) {
-    out->set_rows(input1.rows());
-    out->set_height(input1.height());
-    out->mutable_value()->mutable_data<T>(input1.value().dims(),
-                                          context.GetPlace());
-    auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
+  // elementwise multiply two SelectedRows
+  framework::SelectedRows operator()(const DeviceContext& context,
+                                     const framework::SelectedRows& input1,
+                                     const framework::SelectedRows& input2) {
+    framework::SelectedRows out;
+    out.set_rows(input1.rows());
+    out.set_height(input1.height());
+    out.mutable_value()->mutable_data<T>(input1.value().dims(),
+                                         context.GetPlace());
+    auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
     auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
     auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
     e_out.device(*context.eigen_device()) = e_in1 * e_in2;
+    return out;
+  }
+  // multiply a SelectedRows by a scalar
+  framework::SelectedRows operator()(const DeviceContext& context,
+                                     const framework::SelectedRows& input1,
+                                     const T input2) {
+    framework::SelectedRows out;
+    out.set_rows(input1.rows());
+    out.set_height(input1.height());
+    out.mutable_value()->mutable_data<T>(input1.value().dims(),
+                                         context.GetPlace());
+    auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
+    auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
+    e_out.device(*context.eigen_device()) = input2 * e_in1;
+    return out;
   }
 };
 
+enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
+
+// applies the given ScatterOps op between the selected rows and the matching
+// rows of the dense tensor, writing the result back into the tensor.
+template <typename DeviceContext, typename T>
+struct UpdateToTensor {
+  void operator()(const DeviceContext& context, const ScatterOps& op,
+                  const framework::SelectedRows& input1,
+                  framework::Tensor* input2);
+};
+
 }  // namespace scatter
 }  // namespace math
 }  // namespace operators
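A design note on the new signatures: the functors now return framework::SelectedRows by value instead of filling an output pointer. With named-return-value optimization (or a move, at worst) the result is constructed in place, so the change should cost nothing at runtime while making call sites composable, as in the adagrad kernels above. A minimal standalone sketch of the pattern (toy type, hypothetical names):

#include <cstdio>
#include <utility>
#include <vector>

struct ToyRows {
  std::vector<long long> ids;
  ToyRows() = default;
  ToyRows(const ToyRows& o) : ids(o.ids) { std::puts("deep copy"); }
  ToyRows(ToyRows&& o) noexcept : ids(std::move(o.ids)) { std::puts("move"); }
};

// Value-returning factory in the style of the new MergeAdd signature.
ToyRows MakeMerged() {
  ToyRows out;
  out.ids = {0, 2, 5};
  return out;  // NRVO elides the copy; otherwise the move ctor runs
}

int main() {
  ToyRows merged = MakeMerged();  // prints nothing (NRVO) or "move",
                                  // never "deep copy"
  std::printf("%zu rows\n", merged.ids.size());
  return 0;
}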
diff --git a/python/paddle/v2/fluid/tests/test_adam_op.py b/python/paddle/v2/fluid/tests/test_adam_op.py
index 3758ca457e5db19354a0ac39d9c6683b918ad89c..7dbc2fa0858a68c5da9e8d48dcb187494357e940 100644
--- a/python/paddle/v2/fluid/tests/test_adam_op.py
+++ b/python/paddle/v2/fluid/tests/test_adam_op.py
@@ -285,7 +285,6 @@ class TestSparseAdamOp(unittest.TestCase):
         j = 0
         while j < self.row_numel:
             pos = row_id * self.row_numel + j
-            print(actual[pos] - np_array[pos]) / actual[pos]
             self.assertLess((actual[pos] - np_array[pos]) / actual[pos],
                             0.00001)
             j += 1