Commit 1039c1e3 authored by typhoonzero

scatter optimizers

Parent 641b4c0f
...
@@ -79,12 +79,12 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
                      framework::Tensor* moment, framework::Tensor* param) {
    // 1. g_m.rows = set(g.rows)
    auto grad_width = grad.value().dims()[1];
-   math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
+   math::scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
    auto grad_merge = merge_func(context, grad);
    auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
-   auto& merge_rows = grad_merge.rows;
+   auto& merge_rows = grad_merge.rows();
    // 2. m += g_m * g_m
-   math::scatter::Mul<platform::CPUDeviceContext, T> sqare_func;
+   math::scatter::Mul<platform::CUDADeviceContext, T> sqare_func;
    auto grad_square = sqare_func(context, grad_merge, grad_merge);
    math::SelectedRowsAddToTensor<platform::CUDADeviceContext, T> functor;
...
@@ -95,11 +95,13 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
    auto* param_data = param->data<T>();
    auto* moment_data = moment->data<T>();
+   const int block_size = 256;
+   dim3 threads(block_size, 1);
    dim3 grid2(1, merge_rows.size());
    SparseAdagradFunctorKernel<
        T, 256><<<grid2, threads, 0,
                  reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                     .stream()>>>(grad_merge_data, grad_merge->rows().data(),
+                     .stream()>>>(grad_merge_data, grad_merge.rows().data(),
                                   lr, param_data, moment_data, grad_width,
                                   epsilon);
  }
...
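Note: taken together, the two hunks above make the CUDA path mirror the CPU one. The sparse gradient's duplicate rows are first merged (MergeAdd), the merged gradient is squared and accumulated into the moment, and a kernel then updates only the touched parameter rows. Below is a minimal host-side sketch of the per-row arithmetic, assuming the standard Adagrad rule; the names and layout are illustrative, not taken from the kernel body.

#include <cmath>
#include <cstdint>
#include <vector>

// Sketch: `rows` are the unique row ids produced by MergeAdd, `grad` holds the
// merged gradient (rows.size() x width). Only the touched rows of `moment` and
// `param` are updated, which is the point of the sparse path.
void SparseAdagradSketch(const std::vector<int64_t>& rows,
                         const std::vector<float>& grad, int64_t width,
                         float lr, float epsilon,
                         std::vector<float>* moment, std::vector<float>* param) {
  for (size_t i = 0; i < rows.size(); ++i) {
    for (int64_t j = 0; j < width; ++j) {
      const float g = grad[i * width + j];
      float& m = (*moment)[rows[i] * width + j];
      float& p = (*param)[rows[i] * width + j];
      m += g * g;                              // 2. m += g_m * g_m
      p -= lr * g / (std::sqrt(m) + epsilon);  // standard Adagrad step (assumed)
    }
  }
}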
...
@@ -233,9 +233,8 @@ template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
template <typename T>
struct UpdateToTensor<platform::CPUDeviceContext, T> {
-  framework::Tensor operator()(const platform::CPUDeviceContext& context,
-                               const ScatterOps& op,
-                               const framework::SelectedRows& input1,
-                               framework::Tensor* input2) {
+  void operator()(const platform::CPUDeviceContext& context,
+                  const ScatterOps& op, const framework::SelectedRows& input1,
+                  framework::Tensor* input2) {
    auto in1_height = input1.height();
    auto in2_dims = input2->dims();
...
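Note: the visible change here (and in the CUDA specialization and the header further down) is that UpdateToTensor no longer returns a framework::Tensor; it returns void and writes the result into input2 in place. A hypothetical call site under the new signature; the variable names and the assumption that ScatterOps is visible alongside UpdateToTensor are illustrative, not from the patch.

// Illustrative only: scatter-add a SelectedRows gradient into a dense tensor.
// `ctx` (CPUDeviceContext), `grad` (SelectedRows) and `param` (Tensor) are
// assumed to be provided by the surrounding operator.
math::scatter::UpdateToTensor<platform::CPUDeviceContext, float> update_func;
update_func(ctx, ScatterOps::ADD, grad, &param);
// `param` is modified in place; nothing is returned.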
...
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include <set>
+
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/cuda_helper.h"
...
@@ -251,8 +253,8 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
}

template <typename T>
-struct MergeAdd<platform::GPUDeviceContext, T> {
-  framework::SelectedRows operator()(const platform::GPUDeviceContext& context,
+struct MergeAdd<platform::CUDADeviceContext, T> {
+  framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
                                      const framework::SelectedRows& input) {
    framework::SelectedRows out;
    auto input_rows = input.rows();
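Note: MergeAdd takes a SelectedRows whose row list may contain the same row id several times and produces one with a unique row set, summing the duplicated rows. This is what "// 1. g_m.rows = set(g.rows)" refers to in the Adagrad functor, and what the added #include <set> supports. A standalone sketch of those semantics, independent of the CUDA kernel above:

#include <cstdint>
#include <map>
#include <vector>

// Toy stand-in for SelectedRows: `rows` may contain duplicate ids,
// `values` is rows.size() x width in row-major order.
struct SimpleSelectedRows {
  std::vector<int64_t> rows;
  std::vector<float> values;
  int64_t width;
};

SimpleSelectedRows MergeAddSketch(const SimpleSelectedRows& in) {
  std::map<int64_t, std::vector<float>> merged;
  for (size_t i = 0; i < in.rows.size(); ++i) {
    auto& acc = merged[in.rows[i]];
    acc.resize(in.width, 0.0f);
    for (int64_t j = 0; j < in.width; ++j) {
      acc[j] += in.values[i * in.width + j];  // sum rows that share an id
    }
  }
  SimpleSelectedRows out;
  out.width = in.width;
  for (const auto& kv : merged) {             // unique ids, in sorted order
    out.rows.push_back(kv.first);
    out.values.insert(out.values.end(), kv.second.begin(), kv.second.end());
  }
  return out;
}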
...
@@ -288,10 +290,10 @@ struct MergeAdd<platform::GPUDeviceContext, T> {
  }
};

-template struct MergeAdd<platform::GPUDeviceContext, float>;
-template struct MergeAdd<platform::GPUDeviceContext, double>;
-template struct MergeAdd<platform::GPUDeviceContext, int>;
-template struct MergeAdd<platform::GPUDeviceContext, int64_t>;
+template struct MergeAdd<platform::CUDADeviceContext, float>;
+template struct MergeAdd<platform::CUDADeviceContext, double>;
+template struct MergeAdd<platform::CUDADeviceContext, int>;
+template struct MergeAdd<platform::CUDADeviceContext, int64_t>;

template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
...
@@ -343,14 +345,14 @@ __global__ void UpdateToTensorKernel(const T* selected_rows,
}

template <typename T>
-struct UpdateToTensor<platform::GPUDeviceContext, T> {
-  framework::Tensor operator()(const platform::GPUDeviceContext& context,
-                               const ScatterOps& op,
-                               const framework::SelectedRows& input1,
-                               framework::Tensor* input2) {
+struct UpdateToTensor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& context,
+                  const ScatterOps& op, const framework::SelectedRows& input1,
+                  framework::Tensor* input2) {
    // NOTE: Use SelectedRowsAddToTensor for better performance
    // no additional MergeAdd called.
-   auto merged_in1 = MergeAdd()(context, input1);
+   MergeAdd<platform::CUDADeviceContext, T> merge_func;
+   auto merged_in1 = merge_func(context, input1);
    auto in1_height = merged_in1.height();
    auto in2_dims = input2->dims();
...
@@ -362,14 +364,14 @@ struct UpdateToTensor<platform::GPUDeviceContext, T> {
    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);

-   auto* in1_data = in1_value.data<T>();
-   auto* input2_data = input2->data<T>();
+   auto* in1_data = in1_value.template data<T>();
+   auto* in2_data = input2->data<T>();

-   dim3 threads(PADDLE_CUDA_NUM_THREADS, 1);
+   dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
    dim3 grid(1, in1_rows.size());
-   UpdateToTensorKernel<
-       T, PADDLE_CUDA_NUM_THREADS><<<grid, threads, 0, context.stream()>>>(
-       in1_data, in1_rows.data(), op, in2_data, in1_row_numel);
+   UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
+       grid, threads, 0, context.stream()>>>(in1_data, in1_rows.data(), op,
+                                             in2_data, in1_row_numel);
  }
};

}  // namespace scatter
...
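Note: the launch configuration above assigns one thread block per selected row (grid.y = in1_rows.size()) and lets the block's threads stride across that row's in1_row_numel elements. A simplified kernel sketch with the same indexing; the ADD/ASSIGN handling is an assumption based on the ScatterOps names, not a copy of the real kernel.

// Sketch: blockIdx.y picks a selected row, threads stride over its columns.
template <typename T, int block_size>
__global__ void UpdateToTensorSketch(const T* selected_rows, const int64_t* rows,
                                     T* tensor_out, int64_t row_numel, bool add) {
  const int64_t src_row = blockIdx.y;      // row inside the SelectedRows value
  const int64_t dst_row = rows[src_row];   // matching row of the dense tensor
  for (int64_t j = threadIdx.x; j < row_numel; j += block_size) {
    const T v = selected_rows[src_row * row_numel + j];
    if (add) {
      tensor_out[dst_row * row_numel + j] += v;  // e.g. ScatterOps::ADD
    } else {
      tensor_out[dst_row * row_numel + j] = v;   // e.g. ScatterOps::ASSIGN
    }
  }
}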
...
@@ -123,8 +123,7 @@ enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
// out = seleted_rows_in / tensor
template <typename DeviceContext, typename T>
struct UpdateToTensor {
-  framework::Tensor operator()(const DeviceContext& context,
-                               const ScatterOps& op,
-                               const framework::SelectedRows& input1,
-                               framework::Tensor* input2);
+  void operator()(const DeviceContext& context, const ScatterOps& op,
+                  const framework::SelectedRows& input1,
+                  framework::Tensor* input2);
};
...
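Note: the enum names above suggest the element-wise combination applied between the selected rows and the dense tensor, with SUBBY and DIVBY presumably reversing the operand order (cf. the "// out = seleted_rows_in / tensor" comment). A host-side sketch of that presumed mapping; it mirrors the names only and is not copied from the kernels.

// Presumed per-element semantics: r is the selected-rows value, t the tensor value.
enum class ScatterOpsSketch { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };

float ApplySketch(ScatterOpsSketch op, float r, float t) {
  switch (op) {
    case ScatterOpsSketch::ASSIGN: return r;      // t = r
    case ScatterOpsSketch::ADD:    return t + r;  // t += r
    case ScatterOpsSketch::SUB:    return t - r;  // t -= r
    case ScatterOpsSketch::SUBBY:  return r - t;  // t = r - t
    case ScatterOpsSketch::MUL:    return t * r;  // t *= r
    case ScatterOpsSketch::DIV:    return t / r;  // t /= r
    case ScatterOpsSketch::DIVBY:  return r / t;  // t = r / t ("out = selected_rows_in / tensor")
  }
  return t;  // unreachable for valid enum values
}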