Commit 4b5986bb authored by tensor-tang

enable fc op in normal case

Parent e133df60
@@ -295,12 +295,6 @@ op_library(channel_recv_op DEPS concurrency)
 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
-# The fully connected layer is deleted when the WITH_MKLDNN flag is OFF
-# Because the fully connected layer has only one MKLDNN's operator
-if(NOT WITH_MKLDNN)
-  list(REMOVE_ITEM GENERAL_OPS fc_op)
-endif(NOT WITH_MKLDNN)
 foreach(src ${GENERAL_OPS})
   op_library(${src})
 endforeach()
......
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/fc_op.h"
 #include <vector>
 #include "paddle/fluid/operators/math/blas.h"
+DECLARE_int32(paddle_num_threads);
@@ -127,13 +128,13 @@ class FCOpKernel : public framework::OpKernel<T> {
                    "It must use CPUPlace.");
     auto input = ctx.Input<Tensor>("Input");
     auto w = ctx.Input<Tensor>("W");
-    auto b = ctx.Input<Tensor>("Bias");
+    auto bias = ctx.Input<Tensor>("Bias");
     auto output = ctx.Output<Tensor>("Out");
-    auto in_dims = ctx->GetInputDim("Input");
-    auto w_dims = ctx->GetInputDim("W");
+    auto in_dims = input->dims();
+    auto w_dims = w->dims();

-    auto& dev_ctx = ctx.template device_context<CPUDeviceContext>();
-    auto blas = math::GetBlas<CPUDeviceContext, T>(dev_ctx);
+    auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
     const T* input_data = input->data<T>();
     const T* w_data = w->data<T>();
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
@@ -147,7 +148,7 @@ class FCOpKernel : public framework::OpKernel<T> {
 #pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
       for (int bs = 0; bs < in_dims[0]; bs++) {
         blas.AXPY(w_dims[1], static_cast<T>(1), bias_data,
-                  output_data + bs * w_dimws[1]);
+                  output_data + bs * w_dims[1]);
       }
     }
   }
......
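For readers skimming the diff, the kernel above computes a fully connected layer: Out = Input * W, with Bias added to every row of Out, which is what the per-batch AXPY call in the last hunk does. Below is a minimal standalone sketch of that math, not Paddle code; the function name naive_fc, the raw loops standing in for the BLAS calls, and the row-major layout are assumptions made only for illustration.

#include <vector>

// Illustration only (not from the Paddle sources): computes
//   out[bs][j] = sum_k in[bs][k] * w[k][j] + bias[j]
// for batch size M, input width K and output width N, i.e. the matrix
// product plus the per-row bias add performed by the kernel.
void naive_fc(const std::vector<float>& in, const std::vector<float>& w,
              const std::vector<float>& bias, std::vector<float>* out,
              int M, int K, int N) {
  for (int bs = 0; bs < M; ++bs) {  // one output row per input sample
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += in[bs * K + k] * w[k * N + j];
      }
      (*out)[bs * N + j] = acc + bias[j];  // bias broadcast over the batch
    }
  }
}

In the kernel itself this bias broadcast is the blas.AXPY call over each output row, and the loop over bs is the one parallelized by the OpenMP pragma when FLAGS_paddle_num_threads is greater than 1, which is why DECLARE_int32(paddle_num_threads) is pulled in at the top of the file.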