Commit 2146293d authored by Huihuang Zheng, committed by Zeng Jinle

Fix op registry (#16677)

list of fixed ops:
lookup_table_op
space_to_depth_op
squared_l2_distance_op
squared_l2_norm_op
teacher_student_sigmoid_loss_op
tree_conv_op
warpctc_op

test=develop
Parent 5c364cda
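Context for the diff below: `paddle::framework::DefaultGradOpDescMaker<true>` forwards every input and output of the forward op (plus their gradients) into the generated grad op, keeping all of those tensor buffers alive through the backward pass. Each op fixed here instead registers a handwritten `SingleGradOpDescMaker` that wires in only the variables its grad kernel actually reads, and, where a tensor is needed only for its shape, adds a `DECLARE_NO_NEED_BUFFER_VARS_INFERENCE` declaration so the framework may release the buffer early. A minimal sketch of the pattern, with a hypothetical `my_op` (names are illustrative, not part of this commit):

// Sketch only: the targeted grad-desc-maker pattern used throughout this
// commit. MyOp, MyOpMaker, MyOpGrad, and "my_op" are hypothetical; assumes
// the usual "namespace ops = paddle::operators;" alias.
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(MyOpGradNoBuffer, "X");

class MyOpGradDescMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("my_op_grad");
    // Forward only what the grad kernel reads; nothing else stays alive.
    op->SetInput("X", Input("X"));  // shape-only use, see no-buffer above
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
    return op;
  }
};

REGISTER_OPERATOR(my_op, ops::MyOp, ops::MyOpMaker, ops::MyOpGradDescMaker);
REGISTER_OPERATOR(my_op_grad, ops::MyOpGrad, ops::MyOpGradNoBuffer);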
@@ -28,7 +28,6 @@ hierarchical_sigmoid
leaky_relu
log
logsigmoid
lookup_table
lrn
lstm_unit
lstmp
@@ -57,20 +56,14 @@ sin
softplus
softshrink
softsign
space_to_depth
spp
square
squared_l2_distance
squared_l2_norm
squeeze
stanh
swish
tanh_shrink
teacher_student_sigmoid_loss
tensor_array_to_tensor
thresholded_relu
transpose
tree_conv
unpool
unsqueeze
warpctc
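The two hunks above remove the fixed ops from a spec list that, judging by the deletions matching the commit message, tracks operators still registered with `DefaultGradOpDescMaker`; once an op gains a handwritten grad maker it drops out of this list.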
paddle/fluid/operators/lookup_table_op.cc
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lookup_table_op.h"
#include <memory>
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace paddle {
@@ -119,6 +123,29 @@ or not. And the output only shares the LoD information with input Ids.
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(LookupTableGradOpNoBuffer, "W");
class LookupTableGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("lookup_table_grad");
op->SetInput("W", Input("W"));
op->SetInput("Ids", Input("Ids"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("W"), InputGrad("W"));
op->SetAttrMap(Attrs());
return op;
}
};
class LookupTableOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -131,7 +158,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Out"));
auto data_type = framework::GetDataTypeOfVar(
ctx.InputVar(framework::GradVarName("Out")));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
@@ -159,10 +187,11 @@ class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(lookup_table, ops::LookupTableOp,
paddle::framework::DefaultGradOpDescMaker<true>,
ops::LookupTableOpMaker);
REGISTER_OPERATOR(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
ops::LookupTableGradOpDescMaker);
REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
ops::LookupTableGradOpNoBuffer,
ops::LookupTableOpGradVarTypeInference);
REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
......
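Two details in the lookup_table hunks recur in the files below. First, `lookup_table_grad` needs `W` only for metadata such as the table height, so `W` is declared no-need-buffer and the embedding table's storage can be freed early. Second, because the handwritten maker no longer forwards the forward `Out`, the kernel data type must now be derived from `Out@GRAD`; the old `ctx.InputVar("Out")` lookup would reference a variable the grad op no longer has. A hedged kernel-side sketch of why the no-buffer contract holds (fragment; names are ours, not from the commit):

// Hypothetical grad-kernel fragment: only W's dims are consulted, never its
// data pointer, which is what makes the no-need-buffer declaration for "W"
// safe.
template <typename T>
void SketchLookupTableGrad(const framework::ExecutionContext& ctx) {
  auto* w = ctx.Input<framework::LoDTensor>("W");
  auto* d_out = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
  auto* d_w = ctx.Output<framework::SelectedRows>(framework::GradVarName("W"));
  d_w->set_height(w->dims()[0]);  // shape only; w->data<T>() is never read
  d_w->mutable_value()->Resize(
      framework::make_ddim({d_out->dims()[0], w->dims()[1]}));
  // ... scatter rows of d_out into d_w according to Ids ...
}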
paddle/fluid/operators/space_to_depth_op.cc
@@ -13,12 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/space_to_depth_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class SpaceToDepthOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -100,6 +106,28 @@ class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SpaceToDepthGradOpNoBuffer, "X");
class SpaceToDepthGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("space_to_depth_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetInput("X", Input("X"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
class SpaceToDepthGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -110,6 +138,14 @@ class SpaceToDepthGradOp : public framework::OperatorWithKernel {
"Input(Out@GRAD) shouldn't be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
@@ -117,8 +153,9 @@ class SpaceToDepthGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OPERATOR(space_to_depth, ops::SpaceToDepthOp, ops::SpaceToDepthOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(space_to_depth_grad, ops::SpaceToDepthGradOp);
ops::SpaceToDepthGradOpDescMaker);
REGISTER_OPERATOR(space_to_depth_grad, ops::SpaceToDepthGradOp,
ops::SpaceToDepthGradOpNoBuffer);
REGISTER_OP_CPU_KERNEL(
space_to_depth,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, float>,
......
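The `GetExpectedKernelType` override added to `SpaceToDepthGradOp` is the companion to the no-buffer declaration: the default kernel-type deduction inspects input tensors, and with `X` marked no-need-buffer its allocation may already be gone when the grad op runs. Pinning the data type to `Out@GRAD`, which is always materialized, avoids that. The override from the hunk above, annotated (comments are ours, not from the commit):

// With "X" marked no-need-buffer, its storage may be freed before backward
// runs, so the dtype must come from a tensor the grad op is guaranteed to
// hold -- here, Out@GRAD.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  return framework::OpKernelType(
      ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
      ctx.GetPlace());
}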
paddle/fluid/operators/squared_l2_distance_op.cc
@@ -14,6 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/squared_l2_distance_op.h"
#include <memory>
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
namespace paddle {
namespace operators {
@@ -54,6 +58,34 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SquaredL2DistanceGradOpNoBuffer, "X",
"Y");
class SquaredL2DistanceGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("squared_l2_distance_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetInput("sub_result", Output("sub_result"));
op->SetInput("X", Input("X"));
op->SetInput("Y", Input("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
op->SetAttrMap(Attrs());
return op;
}
};
class SquaredL2DistanceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -88,6 +120,7 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of Out should not be null");
PADDLE_ENFORCE(ctx->HasInput("sub_result"), "SubResult should not be null");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
@@ -102,6 +135,13 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
if (ctx->HasOutput(x_grad_name)) ctx->SetOutputDim(x_grad_name, x_dims);
if (ctx->HasOutput(y_grad_name)) ctx->SetOutputDim(y_grad_name, y_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("sub_result")->type(),
ctx.GetPlace());
}
};
} // namespace operators
@@ -110,8 +150,9 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OPERATOR(squared_l2_distance, ops::SquaredL2DistanceOp,
ops::SquaredL2DistanceOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(squared_l2_distance_grad, ops::SquaredL2DistanceGradOp);
ops::SquaredL2DistanceGradOpDescMaker);
REGISTER_OPERATOR(squared_l2_distance_grad, ops::SquaredL2DistanceGradOp,
ops::SquaredL2DistanceGradOpNoBuffer);
REGISTER_OP_CPU_KERNEL(
squared_l2_distance,
ops::SquaredL2DistanceKernel<paddle::platform::CPUDeviceContext, float>);
......
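squared_l2_distance is the mixed case: the grad kernel genuinely consumes the forward intermediate `sub_result` (the kernel type is pinned to it), while `X` and `Y` survive into backward only as shapes and are therefore both declared no-need-buffer. The maker fragment from above, annotated by input role (comments are ours):

// Annotated fragment in the style of SquaredL2DistanceGradOpDescMaker.
op->SetType("squared_l2_distance_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));  // data
op->SetInput("sub_result", Output("sub_result"));  // data: forward reuse
op->SetInput("X", Input("X"));  // shape only -> no-need-buffer
op->SetInput("Y", Input("Y"));  // shape only -> no-need-buffer
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));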
paddle/fluid/operators/squared_l2_norm_op.cc
@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/squared_l2_norm_op.h"
#include <memory>
namespace paddle {
namespace operators {
@@ -31,6 +33,26 @@ class SquaredL2NormOp : public framework::OperatorWithKernel {
}
};
class SquaredL2NormGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("squared_l2_norm_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetInput("X", Input("X"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
class SquaredL2NormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -67,8 +89,7 @@ $$Out = \sum_{i} X_{i}^2$$
namespace ops = paddle::operators;
REGISTER_OPERATOR(squared_l2_norm, ops::SquaredL2NormOp,
ops::SquaredL2NormOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::SquaredL2NormOpMaker, ops::SquaredL2NormGradOpDescMaker);
REGISTER_OPERATOR(squared_l2_norm_grad, ops::SquaredL2NormGradOp);
REGISTER_OP_CPU_KERNEL(
squared_l2_norm,
......
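squared_l2_norm is the counterexample: its grad kernel reads the values of `X` (for `Out = \sum_{i} X_{i}^2`, `dX = 2 * dOut * X`), so no no-need-buffer declaration is possible, and the gain is limited to not forwarding the unused forward `Out` into the grad op. A scalar sketch of the math (schematic, not the actual Eigen kernel):

// Schematic backward for squared_l2_norm: g is the scalar upstream gradient
// dLoss/dOut. X's data is read elementwise, so its buffer must stay alive.
template <typename T>
void SketchSquaredL2NormGrad(const T* x, T g, T* d_x, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    d_x[i] = static_cast<T>(2) * g * x[i];
  }
}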
paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/teacher_student_sigmoid_loss_op.h"
#include <memory>
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
@@ -55,6 +58,28 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel {
}
};
class TeacherStudentSigmoidLossGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("teacher_student_sigmoid_loss_grad");
op->SetInput("X", Input("X"));
op->SetInput("Label", Input("Label"));
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
class TeacherStudentSigmoidLossGradientOp
: public framework::OperatorWithKernel {
public:
@@ -148,7 +173,7 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(teacher_student_sigmoid_loss,
ops::TeacherStudentSigmoidLossOp,
ops::TeacherStudentSigmoidLossOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::TeacherStudentSigmoidLossGradOpDescMaker);
REGISTER_OPERATOR(teacher_student_sigmoid_loss_grad,
ops::TeacherStudentSigmoidLossGradientOp);
......
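Note that the forward output of teacher_student_sigmoid_loss is named `Y`, not the conventional `Out`, so the maker wires `GradVarName("Y")`; a handwritten maker must match the op's actual variable names or the grad op's inputs silently go missing at runtime.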
paddle/fluid/operators/tree_conv_op.cc
@@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/fluid/operators/tree_conv_op.h"
#include <memory>
#include <string>
namespace paddle {
@@ -86,6 +88,30 @@ class TreeConvOp : public framework::OperatorWithKernel {
}
};
class TreeConvGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("tree_conv_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetInput("Filter", Input("Filter"));
op->SetInput("EdgeSet", Input("EdgeSet"));
op->SetInput("NodesVector", Input("NodesVector"));
op->SetOutput(framework::GradVarName("NodesVector"),
InputGrad("NodesVector"));
op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
op->SetAttrMap(Attrs());
return op;
}
};
class TreeConvGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -115,7 +141,7 @@ class TreeConvGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OPERATOR(tree_conv, ops::TreeConvOp, ops::TreeConvOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::TreeConvGradOpDescMaker);
REGISTER_OPERATOR(tree_conv_grad, ops::TreeConvGradOp);
......
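tree_conv is the multi-input case: the maker declares gradient outputs only for the differentiable inputs (`Filter`, `NodesVector`) and forwards `EdgeSet`, the integer tree structure, purely as data for the backward kernel. An annotated fragment (comments are ours):

// Gradients only where differentiation makes sense: EdgeSet holds integer
// node indices, so it is wired in as data but gets no gradient slot.
op->SetInput("EdgeSet", Input("EdgeSet"));  // structure, not differentiable
op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
op->SetOutput(framework::GradVarName("NodesVector"),
              InputGrad("NodesVector"));
// intentionally no SetOutput(framework::GradVarName("EdgeSet"), ...)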
paddle/fluid/operators/warpctc_op.cc
@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/warpctc_op.h"
#include <memory>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
@@ -118,6 +120,27 @@ http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf).
}
};
class WarpCTCGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("warpctc_grad");
op->SetInput("WarpCTCGrad", Output("WarpCTCGrad"));
op->SetInput("Logits", Input("Logits"));
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
op->SetAttrMap(Attrs());
return op;
}
};
class WarpCTCGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -145,7 +168,7 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::WarpCTCGradOpDescMaker);
REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp);
REGISTER_OP_CPU_KERNEL(
warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>);
......
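warpctc is the case where forward already does the backward's heavy lifting: the CTC library emits the gradient w.r.t. `Logits` as the side output `WarpCTCGrad`, so the maker feeds that forward output into `warpctc_grad`, whose work reduces to scaling it by the upstream `Loss@GRAD`. A schematic of that scaling (our sketch; it ignores the LoD boundaries and normalization options the real kernel handles):

// Schematic only: each timestep of the precomputed WarpCTCGrad is scaled by
// the loss gradient of the sequence it belongs to.
for (int64_t t = 0; t < total_steps; ++t) {
  const int64_t seq = seq_index_of_step[t];  // hypothetical LoD lookup
  for (int64_t k = 0; k < num_classes; ++k) {
    logits_grad[t * num_classes + k] =
        warpctc_grad[t * num_classes + k] * loss_grad[seq];
  }
}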
@@ -18,20 +18,21 @@ import importlib
fluid.core._set_eager_deletion_mode(0.0, 1.0, True)
from test_bilinear_interp_op import *
from test_concat_op import *
from test_elementwise_add_op import *
from test_elementwise_sub_op import *
from test_concat_op import *
from test_fill_constant_batch_size_like_op import *
from test_fill_zeros_like2_op import *
from test_gather_op import *
from test_gaussian_random_batch_size_like_op import *
from test_uniform_random_batch_size_like_op import *
from test_fill_constant_batch_size_like_op import *
from test_linear_chain_crf_op import *
from test_lod_reset_op import *
from test_scatter_op import *
from test_lookup_table_op import *
from test_mean_op import *
from test_slice_op import *
from test_linear_chain_crf_op import *
from test_bilinear_interp_op import *
from test_nearest_interp_op import *
from test_pad2d_op import *
from test_scatter_op import *
from test_sequence_concat import *
from test_seq_conv import *
from test_seq_pool import *
@@ -41,8 +42,10 @@ from test_sequence_pad_op import *
from test_sequence_unpad_op import *
from test_sequence_scatter_op import *
from test_sequence_slice_op import *
from test_pad2d_op import *
from test_fill_zeros_like2_op import *
from test_slice_op import *
from test_space_to_depth_op import *
from test_squared_l2_distance_op import *
from test_uniform_random_batch_size_like_op import *
if __name__ == '__main__':
unittest.main()