diff --git a/paddle/operators/adagrad_op.cu b/paddle/operators/adagrad_op.cu
index 86b3dd860d9a647f6fc0ec1d699f08158772b306..fed2e29367de50b3c4d115ca22e2b6ed59b468bc 100644
--- a/paddle/operators/adagrad_op.cu
+++ b/paddle/operators/adagrad_op.cu
@@ -79,12 +79,12 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
                   framework::Tensor* moment, framework::Tensor* param) {
     // 1. g_m.rows = set(g.rows)
     auto grad_width = grad.value().dims()[1];
-    math::scatter::MergeAdd<platform::GPUDeviceContext, T> merge_func;
+    math::scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
     auto grad_merge = merge_func(context, grad);
     auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
-    auto& merge_rows = grad_merge.rows;
+    auto& merge_rows = grad_merge.rows();
     // 2. m += g_m * g_m
-    math::scatter::Mul<platform::GPUDeviceContext, T> sqare_func;
+    math::scatter::Mul<platform::CUDADeviceContext, T> sqare_func;
     auto grad_square = sqare_func(context, grad_merge, grad_merge);
 
     math::SelectedRowsAddToTensor<platform::CUDADeviceContext, T> functor;
@@ -95,11 +95,13 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
     auto* param_data = param->data<T>();
     auto* moment_data = moment->data<T>();
 
+    const int block_size = 256;
+    dim3 threads(block_size, 1);
     dim3 grid2(1, merge_rows.size());
     SparseAdagradFunctorKernel<
         T, 256><<<grid2, threads, 0,
                   reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                      .stream()>>>(grad_merge_data, grad_merge->rows().data(),
+                      .stream()>>>(grad_merge_data, grad_merge.rows().data(),
                                    lr, param_data, moment_data, grad_width,
                                    epsilon);
   }
diff --git a/paddle/operators/math/selected_rows_functor.cc b/paddle/operators/math/selected_rows_functor.cc
index c9f3c10c61700d7faf91085c3f3d0564662873ae..8a1ebb58c26578f076bf243adfbd51d10c682b99 100644
--- a/paddle/operators/math/selected_rows_functor.cc
+++ b/paddle/operators/math/selected_rows_functor.cc
@@ -233,10 +233,9 @@ template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
 
 template <typename T>
 struct UpdateToTensor<platform::CPUDeviceContext, T> {
-  framework::Tensor operator()(const platform::CPUDeviceContext& context,
-                               const ScatterOps& op,
-                               const framework::SelectedRows& input1,
-                               framework::Tensor* input2) {
+  void operator()(const platform::CPUDeviceContext& context,
+                  const ScatterOps& op, const framework::SelectedRows& input1,
+                  framework::Tensor* input2) {
     auto in1_height = input1.height();
     auto in2_dims = input2->dims();
     PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
diff --git a/paddle/operators/math/selected_rows_functor.cu b/paddle/operators/math/selected_rows_functor.cu
index 48413403db513056299fc578a64ca7a63aea12c3..0ee456f9bc61436bd0f2f8ef20dd1654e7e56d56 100644
--- a/paddle/operators/math/selected_rows_functor.cu
+++ b/paddle/operators/math/selected_rows_functor.cu
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <set>
+
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/selected_rows_functor.h"
 #include "paddle/platform/cuda_helper.h"
@@ -251,8 +253,8 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
 }
 
 template <typename T>
-struct MergeAdd<platform::GPUDeviceContext, T> {
-  framework::SelectedRows operator()(const platform::GPUDeviceContext& context,
+struct MergeAdd<platform::CUDADeviceContext, T> {
+  framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
                                      const framework::SelectedRows& input) {
     framework::SelectedRows out;
     auto input_rows = input.rows();
@@ -288,10 +290,10 @@ struct MergeAdd<platform::GPUDeviceContext, T> {
   }
 };
 
-template struct MergeAdd<platform::GPUDeviceContext, float>;
-template struct MergeAdd<platform::GPUDeviceContext, double>;
-template struct MergeAdd<platform::GPUDeviceContext, int>;
-template struct MergeAdd<platform::GPUDeviceContext, int64_t>;
+template struct MergeAdd<platform::CUDADeviceContext, float>;
+template struct MergeAdd<platform::CUDADeviceContext, double>;
+template struct MergeAdd<platform::CUDADeviceContext, int>;
+template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
 
 template <typename T, int block_size>
 __global__ void UpdateToTensorKernel(const T* selected_rows,
@@ -343,14 +345,14 @@ __global__ void UpdateToTensorKernel(const T* selected_rows,
 }
 
 template <typename T>
-struct UpdateToTensor<platform::GPUDeviceContext, T> {
-  framework::Tensor operator()(const platform::GPUDeviceContext& context,
-                               const ScatterOps& op,
-                               const framework::SelectedRows& input1,
-                               framework::Tensor* input2) {
+struct UpdateToTensor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& context,
+                  const ScatterOps& op, const framework::SelectedRows& input1,
+                  framework::Tensor* input2) {
     // NOTE: Use SelectedRowsAddToTensor for better performance
     // no additional MergeAdd called.
-    auto merged_in1 = MergeAdd<platform::GPUDeviceContext, T>()(context, input1);
+    MergeAdd<platform::CUDADeviceContext, T> merge_func;
+    auto merged_in1 = merge_func(context, input1);
 
     auto in1_height = merged_in1.height();
     auto in2_dims = input2->dims();
@@ -362,14 +364,14 @@ struct UpdateToTensor<platform::GPUDeviceContext, T> {
     int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
     PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
 
-    auto* in1_data = in1_value.data<T>();
-    auto* input2_data = input2->data<T>();
+    auto* in1_data = in1_value.template data<T>();
+    auto* in2_data = input2->data<T>();
 
-    dim3 threads(PADDLE_CUDA_NUM_THREADS, 1);
+    dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
     dim3 grid(1, in1_rows.size());
-    UpdateToTensorKernel<
-        T, PADDLE_CUDA_NUM_THREADS><<<grid, threads, 0, context.stream()>>>(
-        in1_data, in1_rows.data(), op, in2_data, in1_row_numel);
+    UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
+        grid, threads, 0, context.stream()>>>(in1_data, in1_rows.data(), op,
+                                              in2_data, in1_row_numel);
   }
 };
 }  // namespace scatter
diff --git a/paddle/operators/math/selected_rows_functor.h b/paddle/operators/math/selected_rows_functor.h
index d4bef72980f8374c1519c1774b72b67d67397a2c..09d4631905f90f78772368ad71b11826877bdc34 100644
--- a/paddle/operators/math/selected_rows_functor.h
+++ b/paddle/operators/math/selected_rows_functor.h
@@ -123,10 +123,9 @@ enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
 // out = seleted_rows_in / tensor
 template <typename DeviceContext, typename T>
 struct UpdateToTensor {
-  framework::Tensor operator()(const DeviceContext& context,
-                               const ScatterOps& op,
-                               const framework::SelectedRows& input1,
-                               framework::Tensor* input2);
+  void operator()(const DeviceContext& context, const ScatterOps& op,
+                  const framework::SelectedRows& input1,
+                  framework::Tensor* input2);
 };
 
 }  // namespace scatter
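
Reviewer note (commentary, not part of the patch): two themes run through this
change. First, `UpdateToTensor::operator()` now returns `void` and writes its
result through the `framework::Tensor* input2` out-parameter, matching how the
functor is actually used. Second, the scatter functors are consistently keyed
on the device-context type: the removed `platform::GPUDeviceContext` spellings
appear to name a type that does not exist, hence the rename to
`platform::CUDADeviceContext`, with each functor specialized per device context
in its .cc/.cu file and explicitly instantiated for every supported element
type. The sketch below is a minimal, self-contained illustration of that
specialization-plus-instantiation idiom in plain C++; `CPUContext` and the toy
`MergeAdd` body are hypothetical stand-ins for this example, not Paddle APIs.

    // Toy model of the "functor specialized per device context" pattern.
    #include <cstdint>
    #include <cstdio>
    #include <map>
    #include <utility>
    #include <vector>

    struct CPUContext {};  // stand-in for platform::CPUDeviceContext

    // Primary template: declared but never defined here, so using an
    // unsupported <Context, T> combination fails to compile or link.
    template <typename Context, typename T>
    struct MergeAdd;

    // Per-context partial specialization, defined in one translation unit
    // (the real code does this in selected_rows_functor.cc / .cu).
    template <typename T>
    struct MergeAdd<CPUContext, T> {
      // Sums values that share a row index, like merging SelectedRows rows.
      std::vector<std::pair<int64_t, T>> operator()(
          const CPUContext&, const std::vector<std::pair<int64_t, T>>& in) {
        std::map<int64_t, T> acc;
        for (const auto& rv : in) acc[rv.first] += rv.second;
        return {acc.begin(), acc.end()};
      }
    };

    // Explicit instantiations, mirroring the
    // "template struct MergeAdd<platform::CUDADeviceContext, ...>;"
    // lines in the patch.
    template struct MergeAdd<CPUContext, float>;
    template struct MergeAdd<CPUContext, double>;

    int main() {
      CPUContext ctx;
      // A named functor instance, the style the patch switches to.
      MergeAdd<CPUContext, float> merge_func;
      auto merged = merge_func(ctx, {{0, 1.f}, {2, 3.f}, {0, 4.f}});
      for (const auto& rv : merged)
        std::printf("row %lld -> %.1f\n", static_cast<long long>(rv.first),
                    rv.second);
      return 0;
    }

Keeping the primary template undefined (rather than providing a generic
fallback) is what forces every new device context to come with its own
specialization and instantiation, which is the property the patch restores.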