提交 dae249b1 编写于 作者: L Liu Yiqun

Delete USE_OP statements and add more ENFORCE statements to check the inputs and outputs in FCOp.

上级 6ce4bf36
...@@ -24,6 +24,15 @@ class FCOp : public NetOp { ...@@ -24,6 +24,15 @@ class FCOp : public NetOp {
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: NetOp(type, inputs, outputs, attrs) { : NetOp(type, inputs, outputs, attrs) {
PADDLE_ENFORCE(!Inputs("X").empty(),
"Inputs(X) of FCOp should not be null.");
PADDLE_ENFORCE(!Inputs("W").empty(),
"Inputs(W) of FCOp should not be null.");
PADDLE_ENFORCE(!Outputs("MulOut").empty(),
"Outputs(MulOut) of FCOp should not be null.");
PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
"Output(Out) of FCOp should not be null.");
auto x = Inputs("X"); auto x = Inputs("X");
auto w = Inputs("W"); auto w = Inputs("W");
auto mul_out = Outputs("MulOut"); auto mul_out = Outputs("MulOut");
...@@ -68,6 +77,10 @@ class FCOp : public NetOp { ...@@ -68,6 +77,10 @@ class FCOp : public NetOp {
// sum_out = X[0] * W[0] + ... + X[n-1] * W[n-1] // sum_out = X[0] * W[0] + ... + X[n-1] * W[n-1]
auto sum_out = mul_out[0]; auto sum_out = mul_out[0];
if (n > 1) { if (n > 1) {
PADDLE_ENFORCE_NE(Output("SumOut"), framework::kEmptyVarName,
"Output(SumOut) of FCOp should not be null when the "
"size of Inputs(X) > 1.");
sum_out = Output("SumOut"); sum_out = Output("SumOut");
AppendOp(framework::OpRegistry::CreateOp("sum", {{"X", {mul_out}}}, AppendOp(framework::OpRegistry::CreateOp("sum", {{"X", {mul_out}}},
{{"Out", {sum_out}}}, {})); {{"Out", {sum_out}}}, {}));
...@@ -81,6 +94,10 @@ class FCOp : public NetOp { ...@@ -81,6 +94,10 @@ class FCOp : public NetOp {
auto b = Input("B"); auto b = Input("B");
auto add_out = sum_out; auto add_out = sum_out;
if (b != framework::kEmptyVarName) { if (b != framework::kEmptyVarName) {
PADDLE_ENFORCE_NE(
Output("AddOut"), framework::kEmptyVarName,
"Output(AddOut) of FCOp should not be null when Input(B) is set.");
add_out = Output("AddOut"); add_out = Output("AddOut");
AppendOp(framework::OpRegistry::CreateOp( AppendOp(framework::OpRegistry::CreateOp(
"rowwise_add", {{"X", {sum_out}}, {"b", {Input("B")}}}, "rowwise_add", {{"X", {sum_out}}, {"b", {Input("B")}}},
...@@ -176,11 +193,5 @@ Activation type can be set to `identity` (default), `sigmoid` or `softmax`. ...@@ -176,11 +193,5 @@ Activation type can be set to `identity` (default), `sigmoid` or `softmax`.
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
USE_OP(mul);
USE_OP(rowwise_add);
USE_NO_KERNEL_OP(identity);
USE_OP(sigmoid);
USE_OP(softmax);
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fc, ops::FCOp, ops::FCOpMaker); REGISTER_OP_WITHOUT_GRADIENT(fc, ops::FCOp, ops::FCOpMaker);
...@@ -44,8 +44,8 @@ class IdentityOp : public NetOp { ...@@ -44,8 +44,8 @@ class IdentityOp : public NetOp {
: NetOp(type, inputs, outputs, attrs) { : NetOp(type, inputs, outputs, attrs) {
PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName, PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName,
"Input(X) of IdentityOp should not be null."); "Input(X) of IdentityOp should not be null.");
PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName, PADDLE_ENFORCE_NE(Output("Y"), framework::kEmptyVarName,
"Output(Out) of IdentityOp should not be null."); "Output(Y) of IdentityOp should not be null.");
AppendOp(framework::OpRegistry::CreateOp( AppendOp(framework::OpRegistry::CreateOp(
"scale", {{"X", {Input("X")}}}, {{"Out", {Output("Y")}}}, "scale", {{"X", {Input("X")}}}, {{"Out", {Output("Y")}}},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册