Commit 4b5986bb authored by T tensor-tang

enable fc op in normal case

Parent e133df60
paddle/fluid/operators/CMakeLists.txt
@@ -295,12 +295,6 @@ op_library(channel_recv_op DEPS concurrency)
 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
-# The fully connected layer is deleted when the WITH_MKLDNN flag is OFF
-# Because the fully connected layer has only one MKLDNN's operator
-if(NOT WITH_MKLDNN)
-  list(REMOVE_ITEM GENERAL_OPS fc_op)
-endif(NOT WITH_MKLDNN)
 foreach(src ${GENERAL_OPS})
   op_library(${src})
 endforeach()
......
paddle/fluid/operators/fc_op.cc
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/fc_op.h"
 #include <vector>
+#include "paddle/fluid/operators/math/blas.h"
 DECLARE_int32(paddle_num_threads);
@@ -127,13 +128,13 @@ class FCOpKernel : public framework::OpKernel<T> {
                    "It must use CPUPlace.");
     auto input = ctx.Input<Tensor>("Input");
     auto w = ctx.Input<Tensor>("W");
-    auto b = ctx.Input<Tensor>("Bias");
+    auto bias = ctx.Input<Tensor>("Bias");
     auto output = ctx.Output<Tensor>("Out");
-    auto in_dims = ctx->GetInputDim("Input");
-    auto w_dims = ctx->GetInputDim("W");
-    auto& dev_ctx = ctx.template device_context<CPUDeviceContext>();
-    auto blas = math::GetBlas<CPUDeviceContext, T>(dev_ctx);
+    auto in_dims = input->dims();
+    auto w_dims = w->dims();
+    auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(dev_ctx);
     const T* input_data = input->data<T>();
     const T* w_data = w->data<T>();
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
@@ -147,7 +148,7 @@ class FCOpKernel : public framework::OpKernel<T> {
 #pragma omp parallel for if (FLAGS_paddle_num_threads > 1)
     for (int bs = 0; bs < in_dims[0]; bs++) {
       blas.AXPY(w_dims[1], static_cast<T>(1), bias_data,
-                output_data + bs * w_dimws[1]);
+                output_data + bs * w_dims[1]);
     }
   }
 }
......
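For context on what the patched kernel computes: Out[M x N] = Input[M x K] * W[K x N], followed by broadcasting the 1 x N bias row onto each of the M output rows, which is exactly what the per-row AXPY in the last hunk does. The sketch below is a minimal standalone illustration of that computation, with a naive triple loop standing in for the kernel's BLAS GEMM call; FCForward and its M/K/N parameters are hypothetical names for this example, not Paddle APIs.

```cpp
#include <cstdio>
#include <vector>

// Sketch of the FC forward pass: Out[M x N] = Input[M x K] * W[K x N],
// plus an optional Bias[1 x N] broadcast onto every output row.
// The triple loop stands in for the BLAS GEMM; the bias loop mirrors the
// diff's per-row AXPY of length w_dims[1], offset by bs * w_dims[1].
void FCForward(const float* input, const float* w, const float* bias,
               float* output, int M, int K, int N) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += input[m * K + k] * w[k * N + n];
      }
      output[m * N + n] = acc;
    }
  }
  if (bias != nullptr) {
    // Same parallelization idea as the kernel's guarded omp pragma.
#pragma omp parallel for
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        output[m * N + n] += bias[n];  // AXPY with alpha = 1
      }
    }
  }
}

int main() {
  // 2 x 3 input, 3 x 2 weights (row-major), 1 x 2 bias.
  std::vector<float> in = {1, 2, 3, 4, 5, 6};
  std::vector<float> w = {1, 0, 0, 1, 1, 1};
  std::vector<float> b = {0.5f, -0.5f};
  std::vector<float> out(4);
  FCForward(in.data(), w.data(), b.data(), out.data(), 2, 3, 2);
  for (float v : out) std::printf("%g ", v);  // prints: 4.5 4.5 10.5 10.5
  std::printf("\n");
}
```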