提交 038cbf79 编写于 作者: T tensor-tang

add bias for fc op

上级 49ad570e
...@@ -30,21 +30,34 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -30,21 +30,34 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
auto w_dims = ctx->GetInputDim("W"); auto w_dims = ctx->GetInputDim("W");
std::vector<int64_t> output_shape({in_dims[0], w_dims[1]}); std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim].");
PADDLE_ENFORCE_EQ(bias_dims[1], framework::product(w_dims) / w_dims[0],
"The shape of Bias must be [1, dim].");
}
PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4, PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor."); "Fully Connected input should be 2-D or 4-D tensor.");
PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4, PADDLE_ENFORCE(w_dims.size() == 2 || w_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor."); "Fully Connected input should be 2-D or 4-D tensor.");
PADDLE_ENFORCE_EQ(framework::product(w_dims) / w_dims[0],
framework::product(in_dims) / in_dims[0],
"Fully Connected input and weigth size do not match.");
ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Out"); ctx->ShareLoD("Input", "Out");
} }
framework::OpKernelType FCOp::GetExpectedKernelType( framework::OpKernelType FCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library{framework::LibraryType::kMKLDNN}; framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout{framework::DataLayout::kMKLDNN}; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
if (ctx.Attr<bool>("use_mkldnn");) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout, library); layout, library);
...@@ -60,13 +73,22 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const { ...@@ -60,13 +73,22 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
if (ctx->HasOutput(framework::GradVarName("W"))) { if (ctx->HasOutput(framework::GradVarName("W"))) {
ctx->SetOutputDim(framework::GradVarName("W"), w_dims); ctx->SetOutputDim(framework::GradVarName("W"), w_dims);
} }
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias"));
ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims);
}
} }
framework::OpKernelType FCOpGrad::GetExpectedKernelType( framework::OpKernelType FCOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library{framework::LibraryType::kMKLDNN}; framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout{framework::DataLayout::kMKLDNN}; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
if (ctx.Attr<bool>("use_mkldnn");) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(), framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout, library); layout, library);
...@@ -75,12 +97,12 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType( ...@@ -75,12 +97,12 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
void FCOpMaker::Make() { void FCOpMaker::Make() {
AddInput("Input", "(Tensor) The input tensor of fully connected operator. "); AddInput("Input", "(Tensor) The input tensor of fully connected operator. ");
AddInput("W", "(Tensor), The second input tensor of fc op."); AddInput("W", "(Tensor), The second input tensor of fc op.");
AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x D")
.AsDispensable();
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. "); AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("bias_attr", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Fully Connected Operator. Fully Connected Operator.
...@@ -94,9 +116,39 @@ void FCOpMaker::Make() { ...@@ -94,9 +116,39 @@ void FCOpMaker::Make() {
)DOC"); )DOC");
} }
template <typename T>
class FCOpKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<CPUDeviceContext>();
auto blas = math::GetBlas<CPUDeviceContext, T>(dev_ctx);
auto input = ctx.Input<Tensor>("Input");
auto w = ctx.Input<Tensor>("W");
auto b = ctx.Input<Tensor>("Bias");
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
auto output = ctx.Output<Tensor>("Out");
T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto in_dims = ctx->GetInputDim("Input");
auto w_dims = ctx->GetInputDim("W");
std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
if (bias) {
const T* bias_data = bias->data<T>();
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OPERATOR(fc, paddle::operators::FCOp, paddle::operators::FCOpMaker, namespace ops = paddle::operators;
REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(fc_grad, paddle::operators::FCOpGrad); REGISTER_OPERATOR(fc_grad, ops::FCOpGrad);
REGISTER_OP_CPU_KERNEL(fc, ops::FCMKLDNNOpKernel<float>,
ops::FCMKLDNNOpKernel<double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册