提交 46e14bbc 编写于 作者: M mozga-intel

Enforce: 2 and 4 dims, remove information about out in format

上级 32f8ac7d
......@@ -125,10 +125,10 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto input = ctx.Input<Tensor>("Input");
auto w = ctx.Input<Tensor>("W");
PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 2,
PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4,
"Input must be with 2 or 4 dimensions, i.e. NCHW");
PADDLE_ENFORCE(w->dims().size() == 2,
"Weights must be with 2 dimensions, i.e. NC");
PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4,
"Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");
bool with_bias = ctx.Attr<bool>("bias_attr");
MKLDNNMD<Tensor> md(input, w, with_bias);
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fc_op.h"
#include <vector>
namespace paddle {
namespace operators {
......@@ -29,11 +30,11 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
auto w_dims = ctx->GetInputDim("W");
std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 2,
PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor.");
PADDLE_ENFORCE(w_dims.size() == 2,
"Fully Connected input should be 2-D tensor.");
PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor.");
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Out");
......@@ -73,19 +74,9 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"Input",
"(Tensor) The input tensor of fully connected operator. "
"The format of input tensor is NCHW, where N is batch size, C is the "
"number of channels, H is the height of the feature, "
"and W is the width of the feature.");
AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
AddInput("W", "(Tensor), The second input tensor of fc op.");
AddOutput("Out",
"(Tensor) The output tensor of fully connected operator. "
"The format of output tensor is also NCHW, "
"where N is batch size, C is the number of channels, "
"H is the height of the feature, "
"and W is the width of the feature.");
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册