未验证 提交 d6582449 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix some grad op desc maker (#16581)

test=develop
上级 423bc515
...@@ -3,7 +3,6 @@ acos ...@@ -3,7 +3,6 @@ acos
asin asin
atan atan
attention_lstm attention_lstm
bilinear_tensor_product
brelu brelu
conv_shift conv_shift
cos cos
...@@ -43,7 +42,6 @@ max_pool3d_with_index ...@@ -43,7 +42,6 @@ max_pool3d_with_index
maxout maxout
modified_huber_loss modified_huber_loss
nce nce
norm
pool2d pool2d
pool3d pool3d
pow pow
...@@ -60,16 +58,7 @@ reshape ...@@ -60,16 +58,7 @@ reshape
rnn_memory_helper rnn_memory_helper
round round
row_conv row_conv
sequence_concat
sequence_conv
sequence_expand
sequence_expand_as
sequence_pad
sequence_scatter
sequence_slice
sequence_softmax sequence_softmax
sequence_unpad
sigmoid_cross_entropy_with_logits
sin sin
softplus softplus
softshrink softshrink
...@@ -84,7 +73,6 @@ stanh ...@@ -84,7 +73,6 @@ stanh
swish swish
tanh_shrink tanh_shrink
teacher_student_sigmoid_loss teacher_student_sigmoid_loss
temporal_shift
tensor_array_to_tensor tensor_array_to_tensor
thresholded_relu thresholded_relu
transpose transpose
......
...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/affine_grid_op.h" #include "paddle/fluid/operators/affine_grid_op.h"
#include <memory>
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
...@@ -173,9 +175,10 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { ...@@ -173,9 +175,10 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
auto theta_dims = ctx->GetInputDim("Theta");
if (ctx->HasOutput(framework::GradVarName("Theta"))) { if (ctx->HasOutput(framework::GradVarName("Theta"))) {
ctx->SetOutputDim(framework::GradVarName("Theta"), theta_dims); auto output_dims = ctx->GetInputDim(framework::GradVarName("Output"));
ctx->SetOutputDim(framework::GradVarName("Theta"),
{output_dims[0], 2, 3});
} }
} }
......
...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/bilinear_tensor_product_op.h" #include "paddle/fluid/operators/bilinear_tensor_product_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -121,15 +124,9 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { ...@@ -121,15 +124,9 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
"The second dimension of input(Out@GRAD) must be equal to " "The second dimension of input(Out@GRAD) must be equal to "
"the third dimension of the Input(Weight)."); "the third dimension of the Input(Weight).");
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(
bias_dims[1], out_dims[1],
"The second dimension of input(Out@GRAD) must be equal to "
"the second dimension of the Input(Bias).");
auto bias_grad_name = framework::GradVarName("Bias"); auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name)) if (ctx->HasOutput(bias_grad_name)) {
ctx->SetOutputDim(bias_grad_name, bias_dims); ctx->SetOutputDim(bias_grad_name, {1, out_dims[1]});
} }
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
...@@ -148,13 +145,39 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel { ...@@ -148,13 +145,39 @@ class BilinearTensorProductOpGrad : public framework::OperatorWithKernel {
} }
}; };
class BilinearTensorProductGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("bilinear_tensor_product_grad");
op->SetAttrMap(Attrs());
op->SetInput("X", Input("X"));
op->SetInput("Y", Input("Y"));
op->SetInput("Weight", Input("Weight"));
if (ForwardOp().Inputs().count("Bias") > 0) {
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
}
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
return op;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(bilinear_tensor_product, ops::BilinearTensorProductOp, REGISTER_OPERATOR(bilinear_tensor_product, ops::BilinearTensorProductOp,
ops::BilinearTensorProductOpMaker, ops::BilinearTensorProductOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); ops::BilinearTensorProductGradOpDescMaker);
REGISTER_OPERATOR(bilinear_tensor_product_grad, REGISTER_OPERATOR(bilinear_tensor_product_grad,
ops::BilinearTensorProductOpGrad); ops::BilinearTensorProductOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -107,8 +108,6 @@ class GroupNormGradOp : public framework::OperatorWithKernel { ...@@ -107,8 +108,6 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
// check input // check input
PADDLE_ENFORCE(ctx->HasInput("Y"), PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of GroupNormOp should not be null."); "Input(Y) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mean"),
"Input(Mean) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Variance"), PADDLE_ENFORCE(ctx->HasInput("Variance"),
"Input(Variance) of GroupNormOp should not be null."); "Input(Variance) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
...@@ -159,7 +158,6 @@ class GroupNormGradMaker : public framework::SingleGradOpDescMaker { ...@@ -159,7 +158,6 @@ class GroupNormGradMaker : public framework::SingleGradOpDescMaker {
op->SetInput("Bias", Input("Bias")); op->SetInput("Bias", Input("Bias"));
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetInput("Y", Output("Y")); op->SetInput("Y", Output("Y"));
op->SetInput("Mean", Output("Mean"));
op->SetInput("Variance", Output("Variance")); op->SetInput("Variance", Output("Variance"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
......
...@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/norm_op.h" #include "paddle/fluid/operators/norm_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -74,6 +78,24 @@ class NormOpGrad : public framework::OperatorWithKernel { ...@@ -74,6 +78,24 @@ class NormOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
}; };
class NormOpGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("norm_grad");
op->SetAttrMap(Attrs());
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetInput("Norm", Output("Norm"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return op;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -81,7 +103,7 @@ namespace ops = paddle::operators; ...@@ -81,7 +103,7 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext; using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker, REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); ops::NormOpGradOpDescMaker);
REGISTER_OPERATOR(norm_grad, ops::NormOpGrad); REGISTER_OPERATOR(norm_grad, ops::NormOpGrad);
REGISTER_OP_CPU_KERNEL(norm, ops::NormKernel<CPU, float>, REGISTER_OP_CPU_KERNEL(norm, ops::NormKernel<CPU, float>,
ops::NormKernel<CPU, double>); ops::NormKernel<CPU, double>);
......
...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -612,7 +615,8 @@ class Pad2dOpGrad : public framework::OperatorWithKernel { ...@@ -612,7 +615,8 @@ class Pad2dOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
...@@ -625,7 +629,9 @@ class Pad2dOpGradMaker : public framework::SingleGradOpDescMaker { ...@@ -625,7 +629,9 @@ class Pad2dOpGradMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> Apply() const override { std::unique_ptr<framework::OpDesc> Apply() const override {
auto* bind = new framework::OpDesc(); auto* bind = new framework::OpDesc();
bind->SetInput("X", Input("X")); bind->SetInput("X", Input("X"));
if (ForwardOp().Inputs().count("Paddings") > 0) {
bind->SetInput("Paddings", Input("Paddings")); bind->SetInput("Paddings", Input("Paddings"));
}
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetOutput(framework::GradVarName("X"), InputGrad("X")); bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
bind->SetAttrMap(Attrs()); bind->SetAttrMap(Attrs());
...@@ -634,6 +640,10 @@ class Pad2dOpGradMaker : public framework::SingleGradOpDescMaker { ...@@ -634,6 +640,10 @@ class Pad2dOpGradMaker : public framework::SingleGradOpDescMaker {
} }
}; };
// TODO(zjl): Paddings can also be skipped!
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(Pad2dOpGradNoNeedBufferVarsInference,
"X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -641,6 +651,7 @@ namespace ops = paddle::operators; ...@@ -641,6 +651,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(pad2d, ops::Pad2dOp, ops::Pad2dOpMaker, REGISTER_OPERATOR(pad2d, ops::Pad2dOp, ops::Pad2dOpMaker,
ops::Pad2dOpGradMaker); ops::Pad2dOpGradMaker);
REGISTER_OPERATOR(pad2d_grad, ops::Pad2dOpGrad); REGISTER_OPERATOR(pad2d_grad, ops::Pad2dOpGrad,
ops::Pad2dOpGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(pad2d, ops::Pad2dCPUKernel<float>); REGISTER_OP_CPU_KERNEL(pad2d, ops::Pad2dCPUKernel<float>);
REGISTER_OP_CPU_KERNEL(pad2d_grad, ops::Pad2dGradCPUKernel<float>); REGISTER_OP_CPU_KERNEL(pad2d_grad, ops::Pad2dGradCPUKernel<float>);
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/sequence_ops/sequence_concat_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_concat_op.h"
#include <memory>
#include <vector> #include <vector>
namespace paddle { namespace paddle {
...@@ -73,13 +74,43 @@ class SeqConcatShapeInferer : public framework::InferShapeBase { ...@@ -73,13 +74,43 @@ class SeqConcatShapeInferer : public framework::InferShapeBase {
} }
}; };
class SeqConcatGradShapeInferer : public framework::InferShapeBase { class SeqConcatGradOpDescMaker : public framework::SingleGradOpDescMaker {
public: public:
void operator()(framework::InferShapeContext *context) const override { using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_concat_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X", false));
op->SetAttrMap(Attrs());
return op;
}
};
class SeqConcatGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
context->SetOutputsDim(framework::GradVarName("X"), context->SetOutputsDim(framework::GradVarName("X"),
context->GetInputsDim("X")); context->GetInputsDim("X"));
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SeqConcatGradNoNeedBufferVarsInference,
"X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -87,14 +118,14 @@ namespace op = paddle::operators; ...@@ -87,14 +118,14 @@ namespace op = paddle::operators;
REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel, REGISTER_OPERATOR(sequence_concat, paddle::framework::OperatorWithKernel,
op::SeqConcatOpMaker, op::SeqConcatShapeInferer, op::SeqConcatOpMaker, op::SeqConcatShapeInferer,
paddle::framework::DefaultGradOpDescMaker<false>); op::SeqConcatGradOpDescMaker);
template <typename T> template <typename T>
using Kernel = op::SeqConcatKernel<paddle::platform::CPUDeviceContext, T>; using Kernel = op::SeqConcatKernel<paddle::platform::CPUDeviceContext, T>;
REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>, REGISTER_OP_CPU_KERNEL(sequence_concat, Kernel<float>, Kernel<double>,
Kernel<int64_t>); Kernel<int64_t>);
REGISTER_OPERATOR(sequence_concat_grad, paddle::framework::OperatorWithKernel, REGISTER_OPERATOR(sequence_concat_grad, op::SeqConcatGradOp,
op::SeqConcatGradShapeInferer); op::SeqConcatGradNoNeedBufferVarsInference);
template <typename T> template <typename T>
using GradKernel = using GradKernel =
op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>; op::SeqConcatGradKernel<paddle::platform::CPUDeviceContext, T>;
......
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#pragma once #pragma once
#include <utility>
#include <vector> #include <vector>
#include "boost/optional.hpp"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/concat_and_split.h"
...@@ -89,37 +91,49 @@ class SeqConcatGradKernel : public framework::OpKernel<T> { ...@@ -89,37 +91,49 @@ class SeqConcatGradKernel : public framework::OpKernel<T> {
dxs[i]->mutable_data<T>(context.GetPlace()); dxs[i]->mutable_data<T>(context.GetPlace());
} }
} }
std::vector<framework::Tensor> sliced_x; std::vector<framework::Tensor> sliced_x;
std::vector<boost::variant<boost::blank, framework::Tensor>> sliced_dx; std::vector<boost::optional<framework::Tensor>> sliced_dx;
for (size_t i = 1; i < xs[0]->lod()[0].size(); ++i) { for (size_t i = 1; i < xs[0]->lod()[0].size(); ++i) {
for (size_t j = 0; j < xs.size(); ++j) { for (size_t j = 0; j < xs.size(); ++j) {
const framework::LoDTensor *x = xs[j]; const framework::LoDTensor *x = xs[j];
framework::DDim x_dims = x->dims();
framework::LoDTensor *dx = dxs[j]; framework::LoDTensor *dx = dxs[j];
auto &x_lod = x->lod()[0]; auto &x_lod = x->lod()[0];
sliced_x.emplace_back(x->Slice(x_lod[i - 1], x_lod[i]));
if (dx != nullptr) { auto prev_lod = x_lod[i - 1];
sliced_dx.emplace_back(dx->Slice(x_lod[i - 1], x_lod[i])); auto next_lod = x_lod[i];
x_dims[0] = next_lod - prev_lod;
sliced_x.emplace_back();
sliced_x.back().Resize(x_dims);
if (dx) {
sliced_dx.emplace_back(dx->Slice(prev_lod, next_lod));
} else { } else {
sliced_dx.emplace_back(boost::blank()); sliced_dx.emplace_back(boost::none);
} }
} }
} }
math::SplitFunctor<DeviceContext, T> functor;
std::vector<const framework::Tensor *> sliced_x_ptr; std::vector<const framework::Tensor *> sliced_x_ptr;
std::vector<framework::Tensor *> sliced_dx_ptr; sliced_x_ptr.reserve(sliced_x.size());
for (auto &x : sliced_x) { for (auto &x : sliced_x) {
sliced_x_ptr.emplace_back(&x); sliced_x_ptr.emplace_back(&x);
} }
std::vector<framework::Tensor *> sliced_dx_ptr;
sliced_dx_ptr.reserve(sliced_dx.size());
for (auto &dx : sliced_dx) { for (auto &dx : sliced_dx) {
try { if (dx) {
sliced_dx_ptr.emplace_back(&boost::get<framework::Tensor>(dx)); sliced_dx_ptr.emplace_back(&dx.get());
} catch (boost::bad_get &) {
sliced_dx_ptr.emplace_back(nullptr);
} }
} }
math::SplitFunctor<DeviceContext, T> functor;
functor(context.template device_context<DeviceContext>(), functor(context.template device_context<DeviceContext>(),
detail::Ref( detail::Ref(
context.Input<framework::Tensor>(framework::GradVarName("Out")), context.Input<framework::Tensor>(framework::GradVarName("Out")),
......
...@@ -15,6 +15,9 @@ limitations under the License. */ ...@@ -15,6 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_conv_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_conv_op.h"
#include <algorithm> #include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -171,13 +174,57 @@ context_length, context_stride and context_start. ...@@ -171,13 +174,57 @@ context_length, context_stride and context_start.
} }
}; };
class SequenceConvGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_conv_grad");
op->SetAttrMap(Attrs());
if (boost::get<bool>(Attrs().at("paddingTrainable")) &&
ForwardOp().Inputs().count("PaddingData") > 0) {
op->SetInput("PaddingData", Input("PaddingData"));
op->SetOutput(framework::GradVarName("PaddingData"),
InputGrad("PaddingData"));
}
op->SetInput("X", Input("X"));
op->SetInput("Filter", Input("Filter"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
return op;
}
};
class SequenceConvGradNoNeedBufferVarsInference
: public framework::NoNeedBufferVarsInference {
public:
using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference;
std::unordered_set<std::string> operator()() const override {
if (!boost::get<bool>(Attrs().at("paddingTrainable"))) {
return {"PaddingData"};
} else {
return {};
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_conv, ops::SequenceConvOp, ops::SequenceConvOpMaker, REGISTER_OPERATOR(sequence_conv, ops::SequenceConvOp, ops::SequenceConvOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); ops::SequenceConvGradOpDescMaker);
REGISTER_OPERATOR(sequence_conv_grad, ops::SequenceConvGradOp);
REGISTER_OPERATOR(sequence_conv_grad, ops::SequenceConvGradOp,
ops::SequenceConvGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_conv, sequence_conv,
......
...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h"
#include <memory>
#include <string>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -70,6 +72,12 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel { ...@@ -70,6 +72,12 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("Y", /*->*/ "Out"); ctx->ShareLoD("Y", /*->*/ "Out");
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
}
}; };
class SequenceExpandAsOpMaker : public framework::OpProtoAndCheckerMaker { class SequenceExpandAsOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -131,7 +139,6 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel { ...@@ -131,7 +139,6 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null."); "Input(Out@GRAD) should not be null.");
...@@ -143,16 +150,48 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel { ...@@ -143,16 +150,48 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
ctx->ShareLoD("X", x_grad_name); ctx->ShareLoD("X", x_grad_name);
} }
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
}; };
class SequenceExpandAsOpGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_expand_as_grad");
op->SetInput("X", Input("X"));
op->SetInput("Y", Input("Y"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequenceExpandAsOpNoNeedBufferVarsInference, "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequenceExpandAsGradOpNoNeedBufferVarsInference, "X", "Y");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_expand_as, ops::SequenceExpandAsOp, REGISTER_OPERATOR(sequence_expand_as, ops::SequenceExpandAsOp,
ops::SequenceExpandAsOpMaker, ops::SequenceExpandAsOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); ops::SequenceExpandAsOpGradOpDescMaker,
REGISTER_OPERATOR(sequence_expand_as_grad, ops::SequenceExpandAsOpGrad); ops::SequenceExpandAsOpNoNeedBufferVarsInference);
REGISTER_OPERATOR(sequence_expand_as_grad, ops::SequenceExpandAsOpGrad,
ops::SequenceExpandAsGradOpNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_expand_as, sequence_expand_as,
ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, float>, ops::SequenceExpandAsKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_expand_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_expand_op.h"
#include <memory>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -96,6 +97,12 @@ class SequenceExpandOp : public framework::OperatorWithKernel { ...@@ -96,6 +97,12 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
}
}; };
class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker { class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -188,7 +195,6 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel { ...@@ -188,7 +195,6 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null."); "Input(Out@GRAD) should not be null.");
...@@ -199,16 +205,47 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel { ...@@ -199,16 +205,47 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(x_grad_name, x_dims); ctx->SetOutputDim(x_grad_name, x_dims);
} }
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
}; };
class SequenceExpandOpGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_expand_grad");
op->SetInput("X", Input("X"));
op->SetInput("Y", Input("Y"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SequenceExpandOpNoNeedBufferVarsInference,
"Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequenceExpandGradOpNoNeedBufferVarsInference, "X", "Y");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_expand, ops::SequenceExpandOp, REGISTER_OPERATOR(sequence_expand, ops::SequenceExpandOp,
ops::SequenceExpandOpMaker, ops::SequenceExpandOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); ops::SequenceExpandOpGradDescMaker,
REGISTER_OPERATOR(sequence_expand_grad, ops::SequenceExpandOpGrad); ops::SequenceExpandOpNoNeedBufferVarsInference);
REGISTER_OPERATOR(sequence_expand_grad, ops::SequenceExpandOpGrad,
ops::SequenceExpandGradOpNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_expand, sequence_expand,
ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>, ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_pad_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_pad_op.h"
#include <memory>
#include <string>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -194,18 +196,39 @@ class SequencePadGradOp : public framework::OperatorWithKernel { ...@@ -194,18 +196,39 @@ class SequencePadGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); auto data_type = framework::GetDataTypeOfVar(
ctx.InputVar(framework::GradVarName("Out")));
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
class SequencePadGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_pad_grad");
op->SetAttrMap(Attrs());
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequencePadGradOpNoNeedBufferVarsInference, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_pad, ops::SequencePadOp, ops::SequencePadOpMaker, REGISTER_OPERATOR(sequence_pad, ops::SequencePadOp, ops::SequencePadOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); ops::SequencePadGradOpDescMaker);
REGISTER_OPERATOR(sequence_pad_grad, ops::SequencePadGradOp); REGISTER_OPERATOR(sequence_pad_grad, ops::SequencePadGradOp,
ops::SequencePadGradOpNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_pad, sequence_pad,
ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, float>, ops::SequencePadOpKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h"
#include <memory>
#include <string> #include <string>
namespace paddle { namespace paddle {
...@@ -114,7 +115,8 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { ...@@ -114,7 +115,8 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -138,13 +140,17 @@ class SequencePoolGradOpMaker : public framework::SingleGradOpDescMaker { ...@@ -138,13 +140,17 @@ class SequencePoolGradOpMaker : public framework::SingleGradOpDescMaker {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequencePoolGradOpNoNeedBufferVarsInference, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker, REGISTER_OPERATOR(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
ops::SequencePoolGradOpMaker); ops::SequencePoolGradOpMaker);
REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp); REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp,
ops::SequencePoolGradOpNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_pool, sequence_pool,
ops::SequencePoolKernel<paddle::platform::CPUDeviceContext, float>); ops::SequencePoolKernel<paddle::platform::CPUDeviceContext, float>);
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_scatter_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_scatter_op.h"
#include <memory>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/gather.h"
...@@ -124,25 +125,49 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel { ...@@ -124,25 +125,49 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("Updates"), ctx->SetOutputDim(framework::GradVarName("Updates"),
ctx->GetInputDim("Updates")); ctx->GetInputDim("Updates"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
} }
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
class SequenceScatterGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_scatter_grad");
op->SetInput("Ids", Input("Ids"));
op->SetInput("Updates", Input("Updates"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Updates"), InputGrad("Updates"));
op->SetAttrMap(Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequenceScatterGradNoNeedBufferVarsInference, "Updates");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_scatter, ops::SequenceScatterOp, REGISTER_OPERATOR(sequence_scatter, ops::SequenceScatterOp,
ops::SequenceScatterOpMaker, ops::SequenceScatterOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); ops::SequenceScatterGradDescMaker);
REGISTER_OPERATOR(sequence_scatter_grad, ops::SequenceScatterGradOp); REGISTER_OPERATOR(sequence_scatter_grad, ops::SequenceScatterGradOp,
ops::SequenceScatterGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL(sequence_scatter, ops::SequenceScatterOpKernel<float>, REGISTER_OP_CPU_KERNEL(sequence_scatter, ops::SequenceScatterOpKernel<float>,
ops::SequenceScatterOpKernel<double>, ops::SequenceScatterOpKernel<double>,
ops::SequenceScatterOpKernel<int>, ops::SequenceScatterOpKernel<int>,
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_slice_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_slice_op.h"
#include <memory>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -70,7 +71,8 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel { ...@@ -70,7 +71,8 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -113,14 +115,35 @@ NOTE: The first dimension size of input, the size of offset and Length, should b ...@@ -113,14 +115,35 @@ NOTE: The first dimension size of input, the size of offset and Length, should b
} }
}; };
class SequenceSliceGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_slice_grad");
op->SetInput("X", Input("X"));
op->SetInput("Offset", Input("Offset"));
op->SetInput("Length", Input("Length"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequenceSliceGradNoNeedBufferVarsInference, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_slice, ops::SequenceSliceOp, REGISTER_OPERATOR(sequence_slice, ops::SequenceSliceOp,
ops::SequenceSliceOpMaker, ops::SequenceSliceOpMaker, ops::SequenceSliceGradOpDescMaker);
paddle::framework::DefaultGradOpDescMaker<true>); REGISTER_OPERATOR(sequence_slice_grad, ops::SequenceSliceGradOp,
REGISTER_OPERATOR(sequence_slice_grad, ops::SequenceSliceGradOp); ops::SequenceSliceGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_slice, sequence_slice,
ops::SequenceSliceOpKernel<paddle::platform::CPUDeviceContext, float>); ops::SequenceSliceOpKernel<paddle::platform::CPUDeviceContext, float>);
......
...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sequence_ops/sequence_unpad_op.h" #include "paddle/fluid/operators/sequence_ops/sequence_unpad_op.h"
#include <memory>
#include <string>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -125,19 +127,39 @@ class SequenceUnpadGradOp : public framework::OperatorWithKernel { ...@@ -125,19 +127,39 @@ class SequenceUnpadGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); auto data_type = framework::GetDataTypeOfVar(
ctx.InputVar(framework::GradVarName("Out")));
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
class SequenceUnpadGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sequence_unpad_grad");
op->SetAttrMap(Attrs());
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(
SequenceUnpadGradOpNoNeedBufferVarsInference, "X");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_unpad, ops::SequenceUnpadOp, REGISTER_OPERATOR(sequence_unpad, ops::SequenceUnpadOp,
ops::SequenceUnpadOpMaker, ops::SequenceUnpadOpMaker, ops::SequenceUnpadGradOpDescMaker);
paddle::framework::DefaultGradOpDescMaker<true>); REGISTER_OPERATOR(sequence_unpad_grad, ops::SequenceUnpadGradOp,
REGISTER_OPERATOR(sequence_unpad_grad, ops::SequenceUnpadGradOp); ops::SequenceUnpadGradOpNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
sequence_unpad, sequence_unpad,
ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, float>, ops::SequenceUnpadOpKernel<paddle::platform::CPUDeviceContext, float>,
......
...@@ -81,10 +81,9 @@ class SequenceUnpadGradOpKernel : public framework::OpKernel<T> { ...@@ -81,10 +81,9 @@ class SequenceUnpadGradOpKernel : public framework::OpKernel<T> {
auto* d_x = ctx.Output<LoDTensor>(framework::GradVarName("X")); auto* d_x = ctx.Output<LoDTensor>(framework::GradVarName("X"));
if (d_x) { if (d_x) {
const auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out")); const auto* d_out = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
const auto* x_t = ctx.Input<LoDTensor>("X");
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
int padded_length = x_t->dims()[1]; int padded_length = d_x->dims()[1];
LoDTensor zero_pads; LoDTensor zero_pads;
zero_pads.Resize({1, 1}); zero_pads.Resize({1, 1});
......
...@@ -11,6 +11,7 @@ limitations under the License. */ ...@@ -11,6 +11,7 @@ limitations under the License. */
#include "paddle/fluid/operators/shuffle_channel_op.h" #include "paddle/fluid/operators/shuffle_channel_op.h"
#include <memory> #include <memory>
#include <string>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -73,12 +74,7 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel { ...@@ -73,12 +74,7 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), auto input_dims = ctx->GetInputDim(framework::GradVarName("Out"));
"Input(Out@Grad) should not be null");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@Grad) should not be null");
auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
ctx->SetOutputDim(framework::GradVarName("X"), input_dims); ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
...@@ -87,7 +83,8 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel { ...@@ -87,7 +83,8 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.device_context()); ctx.device_context());
} }
}; };
...@@ -100,7 +97,6 @@ class ShuffleChannelGradDescMaker : public framework::SingleGradOpDescMaker { ...@@ -100,7 +97,6 @@ class ShuffleChannelGradDescMaker : public framework::SingleGradOpDescMaker {
std::unique_ptr<framework::OpDesc> Apply() const override { std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc()); std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("shuffle_channel_grad"); op->SetType("shuffle_channel_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X")); op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs()); op->SetAttrMap(Attrs());
......
...@@ -78,10 +78,14 @@ template <typename DeviceContext, typename T> ...@@ -78,10 +78,14 @@ template <typename DeviceContext, typename T>
class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> { class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::Tensor>("X"); auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
int group = ctx.Attr<int>("group"); int group = ctx.Attr<int>("group");
auto input_dims = input->dims(); const auto& input_dims = input_grad->dims();
auto num = input_dims[0]; auto num = input_dims[0];
auto channel = input_dims[1]; auto channel = input_dims[1];
auto height = input_dims[2]; auto height = input_dims[2];
...@@ -91,10 +95,7 @@ class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -91,10 +95,7 @@ class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
int group_row = group; int group_row = group;
int group_column = channel / group_row; int group_column = channel / group_row;
auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
const T* output_grad_data = output_grad->data<T>(); const T* output_grad_data = output_grad->data<T>();
......
...@@ -57,10 +57,14 @@ template <typename DeviceContext, typename T> ...@@ -57,10 +57,14 @@ template <typename DeviceContext, typename T>
class ShuffleChannelGradOpKernel : public framework::OpKernel<T> { class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::Tensor>("X"); auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
int group = ctx.Attr<int>("group"); int group = ctx.Attr<int>("group");
auto input_dims = input->dims(); const auto& input_dims = input_grad->dims();
auto num = input_dims[0]; auto num = input_dims[0];
auto channel = input_dims[1]; auto channel = input_dims[1];
auto height = input_dims[2]; auto height = input_dims[2];
...@@ -71,10 +75,6 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> { ...@@ -71,10 +75,6 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
int group_row = group; int group_row = group;
int group_column = channel / group_row; int group_column = channel / group_row;
auto* output_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* input_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
const T* output_grad_data = output_grad->data<T>(); const T* output_grad_data = output_grad->data<T>();
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
......
...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h" #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -139,6 +142,24 @@ However the output only shares the LoD with input `X`. ...@@ -139,6 +142,24 @@ However the output only shares the LoD with input `X`.
} }
}; };
class SigmoidCrossEntropyWithLogitsGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("sigmoid_cross_entropy_with_logits_grad");
op->SetInput("X", Input("X"));
op->SetInput("Label", Input("Label"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -146,7 +167,7 @@ namespace ops = paddle::operators; ...@@ -146,7 +167,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits, REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits,
ops::SigmoidCrossEntropyWithLogitsOp, ops::SigmoidCrossEntropyWithLogitsOp,
ops::SigmoidCrossEntropyWithLogitsOpMaker, ops::SigmoidCrossEntropyWithLogitsOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); ops::SigmoidCrossEntropyWithLogitsGradOpDescMaker);
REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits_grad, REGISTER_OPERATOR(sigmoid_cross_entropy_with_logits_grad,
ops::SigmoidCrossEntropyWithLogitsGradOp); ops::SigmoidCrossEntropyWithLogitsGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/slice_op.h" #include "paddle/fluid/operators/slice_op.h"
#include <algorithm> #include <algorithm>
#include <memory>
#include <vector> #include <vector>
namespace paddle { namespace paddle {
...@@ -135,6 +136,13 @@ class SliceOpGrad : public framework::OperatorWithKernel { ...@@ -135,6 +136,13 @@ class SliceOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(x_grad_name, x_dims); ctx->SetOutputDim(x_grad_name, x_dims);
} }
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
}; };
class SliceOpGradMaker : public framework::SingleGradOpDescMaker { class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
...@@ -153,13 +161,17 @@ class SliceOpGradMaker : public framework::SingleGradOpDescMaker { ...@@ -153,13 +161,17 @@ class SliceOpGradMaker : public framework::SingleGradOpDescMaker {
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SliceOpGradNoNeedBufferVarsInference,
"Input");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker, REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
ops::SliceOpGradMaker); ops::SliceOpGradMaker);
REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad); REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad,
ops::SliceOpGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>, slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
......
...@@ -10,6 +10,9 @@ ...@@ -10,6 +10,9 @@
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/temporal_shift_op.h" #include "paddle/fluid/operators/temporal_shift_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
...@@ -125,29 +128,41 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel { ...@@ -125,29 +128,41 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto dim_x = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) { if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), dim_x); ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
} }
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
class TemporalShiftGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("temporal_shift_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp, REGISTER_OPERATOR(temporal_shift, ops::TemporalShiftOp,
ops::TemporalShiftOpMaker, ops::TemporalShiftOpMaker, ops::TemporalShiftGradOpDescMaker);
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad); REGISTER_OPERATOR(temporal_shift_grad, ops::TemporalShiftOpGrad);
REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>, REGISTER_OP_CPU_KERNEL(temporal_shift, ops::TemporalShiftKernel<float>,
ops::TemporalShiftKernel<double>); ops::TemporalShiftKernel<double>);
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import importlib
fluid.core._set_eager_deletion_mode(0.0, 1.0, True)
from test_elementwise_add_op import *
from test_elementwise_sub_op import *
from test_concat_op import *
from test_gather_op import *
from test_gaussian_random_batch_size_like_op import *
from test_lod_reset_op import *
from test_scatter_op import *
from test_mean_op import *
from test_slice_op import *
from test_linear_chain_crf_op import *
from test_bilinear_interp_op import *
from test_nearest_interp_op import *
from test_sequence_concat import *
from test_seq_conv import *
from test_seq_pool import *
from test_sequence_expand_as import *
from test_sequence_expand import *
from test_sequence_pad_op import *
from test_sequence_unpad_op import *
from test_sequence_scatter_op import *
from test_sequence_slice_op import *
from test_pad2d_op import *
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册