提交 a29b4227 编写于 作者: S sneaxiy

fix sparse gradient clip

上级 b6f61faf
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
namespace paddle { namespace paddle {
...@@ -61,14 +62,32 @@ class ClipKernel : public framework::OpKernel<T> { ...@@ -61,14 +62,32 @@ class ClipKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max"); auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min"); auto min = context.Attr<T>("min");
auto* x = context.Input<Tensor>("X"); auto* x_var = context.InputVar("X");
auto* out = context.Output<Tensor>("Out"); if (x_var->IsType<framework::LoDTensor>()) {
auto* x = context.Input<framework::LoDTensor>("X");
auto* out = context.Output<framework::LoDTensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
int64_t numel = x->numel(); int64_t numel = x->numel();
Transform<DeviceContext> trans; Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x_data, trans(context.template device_context<DeviceContext>(), x_data,
x_data + numel, out_data, ClipFunctor<T>(min, max)); x_data + numel, out_data, ClipFunctor<T>(min, max));
} else if (x_var->IsType<framework::SelectedRows>()) {
auto* x = context.Input<framework::SelectedRows>("X");
auto* out = context.Output<framework::SelectedRows>("Out");
PADDLE_ENFORCE_NE(x, out,
"Inplace clip is not allowed when x is SelectedRows");
math::scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(context.template device_context<DeviceContext>(), *x, out);
auto* out_tensor = out->mutable_value();
auto* out_data = out_tensor->data<T>();
int64_t numel = out_tensor->numel();
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), out_data,
out_data + numel, out_data, ClipFunctor<T>(min, max));
} else {
PADDLE_THROW("ClipOp only supports LoDTensor and SelectedRows");
}
} }
}; };
...@@ -78,10 +97,12 @@ class ClipGradKernel : public framework::OpKernel<T> { ...@@ -78,10 +97,12 @@ class ClipGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max"); auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min"); auto min = context.Attr<T>("min");
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out")); auto* d_out =
auto* d_x = context.Output<Tensor>(framework::GradVarName("X")); context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto* d_x =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (d_x != nullptr) { if (d_x != nullptr) {
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<framework::LoDTensor>("X");
int64_t numel = d_out->numel(); int64_t numel = d_out->numel();
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace()); auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
const T* d_out_data = d_out->data<T>(); const T* d_out_data = d_out->data<T>();
......
...@@ -236,7 +236,7 @@ template <typename T, int block_size> ...@@ -236,7 +236,7 @@ template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input, const int64_t* input_rows, __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
T* out, const int64_t* out_rows, T* out, const int64_t* out_rows,
size_t out_rows_size, int64_t row_numel) { size_t out_rows_size, int64_t row_numel) {
const int ty = blockIdx.y; const int ty = blockIdx.x;
int tid = threadIdx.x; int tid = threadIdx.x;
__shared__ size_t out_idx; __shared__ size_t out_idx;
...@@ -291,12 +291,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> { ...@@ -291,12 +291,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
const int block_size = 256; const int block_size = 256;
dim3 threads(block_size, 1); dim3 threads(block_size, 1);
dim3 grid1(1, input_rows.size()); dim3 grid1(input_rows.size(), 1);
MergeAddKernel< MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
T, 256><<<grid1, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
input_data, input_rows.CUDAData(context.GetPlace()), out_data, input_data, input_rows.CUDAData(context.GetPlace()), out_data,
out.mutable_rows()->CUDAMutableData(context.GetPlace()), out.mutable_rows()->CUDAMutableData(context.GetPlace()),
out.rows().size(), input_width); out.rows().size(), input_width);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册