Commit 81a352af authored by dongzhihong

"test fc without gradient"

Parent: 14424f31
@@ -29,4 +29,4 @@ add_dependencies(framework_py_proto framework_py_proto_init)
 proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
 # cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
 cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
-cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op)
+cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
@@ -22,8 +22,6 @@ namespace framework {
 std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
   auto grad_ops = std::make_shared<PlainNet>();
-  // std::shared_ptr<PlainNet> grad_ops;
-  // grad_ops.reset(new PlainNet);
   for (auto& op : ForwardOps->ops_) {
     auto op_grad = OpRegistry::CreateGradOp(op);
     grad_ops->AddOp(op_grad);
......
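Since CreateGradOp (see the op_registry.h hunk below) can now return nullptr for an operator with no registered gradient, a caller such as AddBackwardOp eventually has to tolerate that; this commit does not yet add the guard (hence the disabled test below). A minimal sketch, assuming the PlainNet/OpRegistry interfaces visible in this diff; the skip logic is an illustration, not part of this commit:

std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
  auto grad_ops = std::make_shared<PlainNet>();
  for (auto& op : ForwardOps->ops_) {
    auto op_grad = OpRegistry::CreateGradOp(op);
    if (op_grad == nullptr) {
      // No gradient registered for this op type (e.g. fc); skip it
      // rather than adding a null OperatorPtr to the backward net.
      continue;
    }
    grad_ops->AddOp(op_grad);
  }
  return grad_ops;
}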
@@ -6,6 +6,7 @@
 USE_OP(add_two);
 USE_OP(mul);
 USE_OP(sigmoid);
+USE_OP(softmax);
 namespace paddle {
 namespace framework {
@@ -75,16 +76,21 @@ TEST(AddBackwardOp, TestGradOp) {
   net->AddOp(
       framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
+  net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
-  // net->AddOp(framework::OpRegistry::CreateOp("fc"), {
-  //     Input("X"), Input("W"), Input("b")},
-  //     {Output("Y")},
-  //     {}
-  // );
   auto grad_ops = AddBackwardOp(net);
   for (auto& op : grad_ops->ops_) {
     op->DebugString();
   }
 }
+// TODO(zhihong): add fc grad without registering.
+// TEST(AddBackwardOp, TestNoGradOp) {
+//   auto net = std::make_shared<PlainNet>();
+//   ASSERT_NE(net, nullptr);
+//   net->AddOp(
+//       framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"}, {}));
+//   auto grad_ops = AddBackwardOp(net);
+//   for (auto& op : grad_ops->ops_) {
+//     op->DebugString();
+//   }
+// }
 }  // namespace framework
 }  // namespace paddle
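For reference, once missing gradients are handled, the disabled test above could be enabled essentially as written. A sketch reconstructed from the commented code; it assumes AddBackwardOp skips the unregistered fc gradient rather than crashing:

TEST(AddBackwardOp, TestNoGradOp) {
  auto net = std::make_shared<PlainNet>();
  ASSERT_NE(net, nullptr);
  net->AddOp(
      framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"}, {}));
  auto grad_ops = AddBackwardOp(net);
  // With no gradient registered for "fc", grad_ops should simply
  // contain no corresponding backward op.
  for (auto& op : grad_ops->ops_) {
    op->DebugString();
  }
}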
@@ -286,7 +286,13 @@ class OpRegistry {
   }
   static OperatorPtr CreateGradOp(OperatorPtr op) {
-    OperatorPtr grad_op(grad_creators().at(op->type_)());
+    auto it = grad_creators().find(op->type_);
+    if (it == grad_creators().end()) {
+      LOG(INFO) << op->type_ << " does not have a gradient op";
+      return nullptr;
+    }
+    // OperatorPtr grad_op(grad_creators().at(op->type_)());
+    OperatorPtr grad_op(it->second());
     grad_op->type_ = op->type_;
     AssembleGradInOut(op, grad_op);
......
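The behavioral change here: grad_creators().at() would throw std::out_of_range for an unregistered op type, whereas find() lets CreateGradOp log and return nullptr, so every caller must now check the result. A minimal caller-side sketch, reusing the op construction from the test code above:

auto op = framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"}, {});
auto grad = framework::OpRegistry::CreateGradOp(op);
if (grad == nullptr) {
  // Expected for now: no REGISTER_GRADIENT_OP entry exists for "fc" yet.
}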
@@ -40,10 +40,23 @@ public:
   }
 };
+
+class SoftmaxOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "SoftmaxOpGrad";
+    return "";
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 namespace ops = paddle::operators;
 REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
+REGISTER_GRADIENT_OP(softmax, paddle::operators::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel<paddle::platform::CPUPlace>);
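With REGISTER_GRADIENT_OP(softmax, ...) in place, softmax now appears in the grad_creators() map that CreateGradOp consults, so the backward pass can construct its gradient op. A quick usage sketch; the test name and the input/output names "X"/"Y" are placeholders, not part of this commit:

USE_OP(softmax);  // pulls in the registrations above

TEST(SoftmaxGrad, CreateGradOp) {
  auto op = framework::OpRegistry::CreateOp("softmax", {"X"}, {"Y"}, {});
  auto grad = framework::OpRegistry::CreateGradOp(op);
  // Non-null now that SoftmaxOpGrad is registered; DebugString()
  // logs "SoftmaxOpGrad" as defined above.
  ASSERT_NE(grad, nullptr);
  grad->DebugString();
}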