提交 6246be29 编写于 作者: chengduoZH

clean gradient data

上级 4d112b7d
......@@ -71,6 +71,8 @@ class SequenceProjectGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Gradient of input(X@GRAD) should not be null.");
if (ctx->Attrs().Get<bool>("padding_trainable")) {
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("PaddingData")),
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
......@@ -177,6 +178,10 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* in = context.Input<LoDTensor>("X");
in_g->mutable_data<T>(context.GetPlace());
if (in_g) {
math::SetConstant<Place, T> functor;
functor(context.device_context(), in_g, 0);
}
auto place = context.GetEigenDevice<Place>();
int context_start = context.Attr<int>("context_start");
......@@ -204,6 +209,8 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
padding_width = padding_data_g->dims()[1];
PADDLE_ENFORCE(padding_width == input_width,
"Input size and pooling size should be consistent.");
math::SetConstant<Place, T> functor;
functor(context.device_context(), padding_data_g, 0);
}
int up_pad = std::max(0, -context_start);
......@@ -282,7 +289,7 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
}
}
if (in && input_row_begin < input_row_end) {
if (in_g && input_row_begin < input_row_end) {
Tensor in_t = in_g->Slice(input_row_begin, input_row_end);
std::vector<int64_t> output_shape(
......
......@@ -87,9 +87,9 @@ class TestSeqProject(OpTest):
def test_check_output(self):
self.check_output()
# def test_check_grad(self):
# self.check_grad(
# set(['X', 'PaddingData']), 'Out', max_relative_error=0.05)
def test_check_grad(self):
self.check_grad(
set(['X', 'PaddingData']), 'Out', max_relative_error=0.05)
# def test_check_grad_no_filter(self):
# self.check_grad(
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册