提交 a29b4227 编写于 作者: S sneaxiy

fix sparse gradient clip

上级 b6f61faf
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
......@@ -61,14 +62,32 @@ class ClipKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min");
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>();
int64_t numel = x->numel();
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x_data,
x_data + numel, out_data, ClipFunctor<T>(min, max));
auto* x_var = context.InputVar("X");
if (x_var->IsType<framework::LoDTensor>()) {
auto* x = context.Input<framework::LoDTensor>("X");
auto* out = context.Output<framework::LoDTensor>("Out");
T* out_data = out->mutable_data<T>(context.GetPlace());
const T* x_data = x->data<T>();
int64_t numel = x->numel();
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x_data,
x_data + numel, out_data, ClipFunctor<T>(min, max));
} else if (x_var->IsType<framework::SelectedRows>()) {
auto* x = context.Input<framework::SelectedRows>("X");
auto* out = context.Output<framework::SelectedRows>("Out");
PADDLE_ENFORCE_NE(x, out,
"Inplace clip is not allowed when x is SelectedRows");
math::scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(context.template device_context<DeviceContext>(), *x, out);
auto* out_tensor = out->mutable_value();
auto* out_data = out_tensor->data<T>();
int64_t numel = out_tensor->numel();
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), out_data,
out_data + numel, out_data, ClipFunctor<T>(min, max));
} else {
PADDLE_THROW("ClipOp only supports LoDTensor and SelectedRows");
}
}
};
......@@ -78,10 +97,12 @@ class ClipGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min");
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
auto* d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto* d_x =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (d_x != nullptr) {
auto* x = context.Input<Tensor>("X");
auto* x = context.Input<framework::LoDTensor>("X");
int64_t numel = d_out->numel();
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
const T* d_out_data = d_out->data<T>();
......
......@@ -236,7 +236,7 @@ template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
T* out, const int64_t* out_rows,
size_t out_rows_size, int64_t row_numel) {
const int ty = blockIdx.y;
const int ty = blockIdx.x;
int tid = threadIdx.x;
__shared__ size_t out_idx;
......@@ -291,12 +291,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid1(1, input_rows.size());
dim3 grid1(input_rows.size(), 1);
MergeAddKernel<
T, 256><<<grid1, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
input_data, input_rows.CUDAData(context.GetPlace()), out_data,
out.mutable_rows()->CUDAMutableData(context.GetPlace()),
out.rows().size(), input_width);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册