提交 46e14bbc 编写于 作者: M mozga-intel

Enforce: 2 and 4 dims, remove information about out in format

上级 32f8ac7d
...@@ -125,10 +125,10 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -125,10 +125,10 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto input = ctx.Input<Tensor>("Input"); auto input = ctx.Input<Tensor>("Input");
auto w = ctx.Input<Tensor>("W"); auto w = ctx.Input<Tensor>("W");
PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 2, PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4,
"Input must be with 2 or 4 dimensions, i.e. NCHW"); "Input must be with 2 or 4 dimensions, i.e. NCHW");
PADDLE_ENFORCE(w->dims().size() == 2, PADDLE_ENFORCE(w->dims().size() == 2 || w->dims().size() == 4,
"Weights must be with 2 dimensions, i.e. NC"); "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");
bool with_bias = ctx.Attr<bool>("bias_attr"); bool with_bias = ctx.Attr<bool>("bias_attr");
MKLDNNMD<Tensor> md(input, w, with_bias); MKLDNNMD<Tensor> md(input, w, with_bias);
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fc_op.h" #include "paddle/fluid/operators/fc_op.h"
#include <vector>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -29,11 +30,11 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -29,11 +30,11 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
auto w_dims = ctx->GetInputDim("W"); auto w_dims = ctx->GetInputDim("W");
std::vector<int64_t> output_shape({in_dims[0], w_dims[1]}); std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 2, PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor."); "Fully Connected input should be 2-D or 4-D tensor.");
PADDLE_ENFORCE(w_dims.size() == 2, PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4,
"Fully Connected input should be 2-D tensor."); "Fully Connected input should be 2-D or 4-D tensor.");
ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Out"); ctx->ShareLoD("Input", "Out");
...@@ -73,19 +74,9 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( ...@@ -73,19 +74,9 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker) FCOpMaker::FCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
"Input",
"(Tensor) The input tensor of fully connected operator. "
"The format of input tensor is NCHW, where N is batch size, C is the "
"number of channels, H is the height of the feature, "
"and W is the width of the feature.");
AddInput("W", "(Tensor), The second input tensor of fc op."); AddInput("W", "(Tensor), The second input tensor of fc op.");
AddOutput("Out", AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
"(Tensor) The output tensor of fully connected operator. "
"The format of output tensor is also NCHW, "
"where N is batch size, C is the number of channels, "
"H is the height of the feature, "
"and W is the width of the feature.");
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册