提交 c4effc7d 编写于 作者: Y Yu Yang

Fix CI Test

上级 e119177a
...@@ -33,7 +33,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp( ...@@ -33,7 +33,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
op_desc.SetType(op.Type()); op_desc.SetType(op.Type());
op_desc.SetAttrMap(op.Attrs()); op_desc.SetAttrMap(op.Attrs());
auto& info = OpInfoMap::Instance().Get(op.Type()); auto& info = OpInfoMap::Instance().Get(op.Type());
auto grad_descs = info.grad_op_maker_(op_desc); auto grad_descs = info.GradOpMaker()(op_desc);
std::vector<std::unique_ptr<OperatorBase>> grad_ops; std::vector<std::unique_ptr<OperatorBase>> grad_ops;
grad_ops.reserve(grad_descs.size()); grad_ops.reserve(grad_descs.size());
std::transform(grad_descs.begin(), grad_descs.end(), std::transform(grad_descs.begin(), grad_descs.end(),
...@@ -49,6 +49,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp( ...@@ -49,6 +49,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
for (auto& grad_op : grad_ops) { for (auto& grad_op : grad_ops) {
net_op->AppendOp(std::move(grad_op)); net_op->AppendOp(std::move(grad_op));
} }
net_op->CompleteAddOp();
return std::unique_ptr<OperatorBase>(net_op); return std::unique_ptr<OperatorBase>(net_op);
} }
} }
......
...@@ -30,7 +30,6 @@ namespace framework { ...@@ -30,7 +30,6 @@ namespace framework {
struct OpInfo { struct OpInfo {
OpCreator creator_; OpCreator creator_;
std::string grad_op_type_;
GradOpMakerFN grad_op_maker_; GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr}; OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr}; OpAttrChecker* checker_{nullptr};
...@@ -51,6 +50,12 @@ struct OpInfo { ...@@ -51,6 +50,12 @@ struct OpInfo {
"Operator Creator has not been registered"); "Operator Creator has not been registered");
return creator_; return creator_;
} }
const GradOpMakerFN& GradOpMaker() const {
PADDLE_ENFORCE_NOT_NULL(grad_op_maker_,
"Operator GradOpMaker has not been registered.");
return grad_op_maker_;
}
}; };
class OpInfoMap { class OpInfoMap {
......
...@@ -137,23 +137,21 @@ class OpKernelRegistrar : public Registrar { ...@@ -137,23 +137,21 @@ class OpKernelRegistrar : public Registrar {
__test_global_namespace_##uniq_name##__>::value, \ __test_global_namespace_##uniq_name##__>::value, \
msg) msg)
#define VA_ARGS(...) , ##__VA_ARGS__ #define REGISTER_OPERATOR(op_type, op_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
#define REGISTER_OPERATOR(op_type, op_class, ...) \ __reg_op__##op_type, \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ "REGISTER_OPERATOR must be called in global namespace"); \
__reg_op__##op_type, \ class _OpClass_##op_type##_ : public op_class { \
"REGISTER_OPERATOR must be called in global namespace"); \ public: \
class _OpClass_##op_type##_ : public op_class { \ DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
public: \ DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ }; \
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \ static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_, \
}; \ ##__VA_ARGS__> \
static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_ VA_ARGS( \ __op_registrar_##op_type##__(#op_type); \
__VA_ARGS__)> \ int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__(#op_type); \ __op_registrar_##op_type##__.Touch(); \
int TouchOpRegistrar_##op_type() { \ return 0; \
__op_registrar_##op_type##__.Touch(); \
return 0; \
} }
/** /**
...@@ -170,7 +168,7 @@ class OpKernelRegistrar : public Registrar { ...@@ -170,7 +168,7 @@ class OpKernelRegistrar : public Registrar {
virtual std::string GradOpType() const { return #grad_op_type; } \ virtual std::string GradOpType() const { return #grad_op_type; } \
}; \ }; \
REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \ REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
op_maker_class) op_maker_class);
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
REGISTER_OPERATOR(op_type, op_class, op_maker_class) REGISTER_OPERATOR(op_type, op_class, op_maker_class)
......
...@@ -72,19 +72,26 @@ class MinusGradMaker : public framework::GradOpDescMakerBase { ...@@ -72,19 +72,26 @@ class MinusGradMaker : public framework::GradOpDescMakerBase {
std::vector<std::unique_ptr<framework::OpDescBind>> operator()() std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
const override { const override {
std::vector<std::unique_ptr<framework::OpDescBind>> ops; std::vector<std::unique_ptr<framework::OpDescBind>> ops;
ops.resize(2); auto x_g = InputGrad("X");
if (!x_g.empty()) {
ops[0].reset(new framework::OpDescBind()); auto *x_g_op = new framework::OpDescBind();
ops[0]->SetType("scale"); x_g_op->SetType("scale");
ops[0]->SetInput("X", OutputGrad("Out")); x_g_op->SetInput("X", OutputGrad("Out"));
ops[0]->SetOutput("Out", InputGrad("X")); x_g_op->SetOutput("Out", x_g);
ops[0]->SetAttr("scale", 1.0f); x_g_op->SetAttr("scale", 1.0f);
ops.emplace_back(x_g_op);
ops[1].reset(new framework::OpDescBind()); }
ops[1]->SetType("scale");
ops[1]->SetInput("X", OutputGrad("Out")); auto y_g = InputGrad("Y");
ops[1]->SetOutput("Out", InputGrad("Y")); if (!y_g.empty()) {
ops[1]->SetAttr("scale", -1.0f); auto *y_g_op = new framework::OpDescBind();
y_g_op->SetType("scale");
y_g_op->SetInput("X", OutputGrad("Out"));
y_g_op->SetOutput("Out", y_g);
y_g_op->SetAttr("scale", -1.0f);
ops.emplace_back(y_g_op);
}
return ops; return ops;
} }
}; };
......
...@@ -121,6 +121,7 @@ class PadOpGradMaker : public framework::SingleGradOpDescMaker { ...@@ -121,6 +121,7 @@ class PadOpGradMaker : public framework::SingleGradOpDescMaker {
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetOutput(framework::GradVarName("X"), InputGrad("X")); bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
bind->SetAttrMap(Attrs()); bind->SetAttrMap(Attrs());
bind->SetType("pad_grad");
return std::unique_ptr<framework::OpDescBind>(bind); return std::unique_ptr<framework::OpDescBind>(bind);
} }
}; };
......
...@@ -14,6 +14,12 @@ ...@@ -14,6 +14,12 @@
#include "paddle/operators/softmax_with_cross_entropy_op.h" #include "paddle/operators/softmax_with_cross_entropy_op.h"
#include <paddle/function/TensorType.h> #include <paddle/function/TensorType.h>
#include <iostream>
#define DBG_LINE() \
do { \
std::cerr << "Run at " << __LINE__ << std::endl; \
} while (false)
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -187,8 +193,7 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker { ...@@ -187,8 +193,7 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp, REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker);
ops::SoftmaxWithCrossEntropyOpMaker);
REGISTER_OPERATOR(softmax_with_cross_entropy_grad, REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyOpGrad); ops::SoftmaxWithCrossEntropyOpGrad);
REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy, REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册