提交 6246be29 编写于 作者: C chengduoZH

clean gradient data

上级 4d112b7d
...@@ -71,6 +71,8 @@ class SequenceProjectGradOp : public framework::OperatorWithKernel { ...@@ -71,6 +71,8 @@ class SequenceProjectGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of Out should not be null."); "Gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Gradient of input(X@GRAD) should not be null.");
if (ctx->Attrs().Get<bool>("padding_trainable")) { if (ctx->Attrs().Get<bool>("padding_trainable")) {
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("PaddingData")), PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("PaddingData")),
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/math/im2col.h" #include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/strided_memcpy.h" #include "paddle/operators/strided_memcpy.h"
namespace paddle { namespace paddle {
...@@ -177,6 +178,10 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> { ...@@ -177,6 +178,10 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X")); auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* in = context.Input<LoDTensor>("X"); auto* in = context.Input<LoDTensor>("X");
in_g->mutable_data<T>(context.GetPlace()); in_g->mutable_data<T>(context.GetPlace());
if (in_g) {
math::SetConstant<Place, T> functor;
functor(context.device_context(), in_g, 0);
}
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
int context_start = context.Attr<int>("context_start"); int context_start = context.Attr<int>("context_start");
...@@ -204,6 +209,8 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> { ...@@ -204,6 +209,8 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
padding_width = padding_data_g->dims()[1]; padding_width = padding_data_g->dims()[1];
PADDLE_ENFORCE(padding_width == input_width, PADDLE_ENFORCE(padding_width == input_width,
"Input size and pooling size should be consistent."); "Input size and pooling size should be consistent.");
math::SetConstant<Place, T> functor;
functor(context.device_context(), padding_data_g, 0);
} }
int up_pad = std::max(0, -context_start); int up_pad = std::max(0, -context_start);
...@@ -282,7 +289,7 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> { ...@@ -282,7 +289,7 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
} }
} }
if (in && input_row_begin < input_row_end) { if (in_g && input_row_begin < input_row_end) {
Tensor in_t = in_g->Slice(input_row_begin, input_row_end); Tensor in_t = in_g->Slice(input_row_begin, input_row_end);
std::vector<int64_t> output_shape( std::vector<int64_t> output_shape(
......
...@@ -87,9 +87,9 @@ class TestSeqProject(OpTest): ...@@ -87,9 +87,9 @@ class TestSeqProject(OpTest):
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
# def test_check_grad(self): def test_check_grad(self):
# self.check_grad( self.check_grad(
# set(['X', 'PaddingData']), 'Out', max_relative_error=0.05) set(['X', 'PaddingData']), 'Out', max_relative_error=0.05)
# def test_check_grad_no_filter(self): # def test_check_grad_no_filter(self):
# self.check_grad( # self.check_grad(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册