未验证 提交 c072998a 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #16219 from luotao1/fc_infershape

refine fc_infershape
...@@ -55,17 +55,8 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -55,17 +55,8 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
"The input tensor Input's rank of FCOp should be larger than " "The input tensor Input's rank of FCOp should be larger than "
"in_num_col_dims."); "in_num_col_dims.");
auto in_mat_dims = framework::flatten_to_2d(in_dims, in_num_col_dims);
PADDLE_ENFORCE_EQ(
in_mat_dims[1], w_dims[0],
"Fully Connected input and weigth size do not match. %s, %s");
std::vector<int64_t> output_dims; std::vector<int64_t> output_dims;
output_dims.reserve(static_cast<size_t>(in_num_col_dims + 1)); FCOutputSize(in_dims, w_dims, output_dims, in_num_col_dims);
for (int i = 0; i < in_num_col_dims; ++i) {
output_dims.push_back(in_dims[i]);
}
output_dims.push_back(w_dims[1]);
ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("Input", "Out"); ctx->ShareLoD("Input", "Out");
...@@ -128,6 +119,9 @@ void FCOpMaker::Make() { ...@@ -128,6 +119,9 @@ void FCOpMaker::Make() {
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
"Skip calling InferShape() function in the runtime.")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
Fully Connected Operator. Fully Connected Operator.
...@@ -142,13 +136,20 @@ class FCOpKernel : public framework::OpKernel<T> { ...@@ -142,13 +136,20 @@ class FCOpKernel : public framework::OpKernel<T> {
void Compute(const paddle::framework::ExecutionContext& ctx) const override { void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace."); "It must use CPUPlace.");
auto input = ctx.Input<Tensor>("Input"); auto input = ctx.Input<framework::LoDTensor>("Input");
auto w = ctx.Input<Tensor>("W"); auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias"); auto bias = ctx.Input<Tensor>("Bias");
auto output = ctx.Output<Tensor>("Out"); auto output = ctx.Output<framework::LoDTensor>("Out");
int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
auto w_dims = w->dims(); auto w_dims = w->dims();
std::vector<int64_t> output_dims;
FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims);
output->Resize(framework::make_ddim(output_dims));
output->set_lod(input->lod());
auto out_dims = output->dims(); auto out_dims = output->dims();
int M = framework::product(out_dims) / out_dims[out_dims.size() - 1]; int M = framework::product(out_dims) / w_dims[1];
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
const T* w_data = w->data<T>(); const T* w_data = w->data<T>();
......
...@@ -48,5 +48,21 @@ class FCOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -48,5 +48,21 @@ class FCOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override; void Make() override;
}; };
inline void FCOutputSize(const framework::DDim& in_dims,
const framework::DDim& w_dims,
std::vector<int64_t>& out_dims, // NOLINT
int in_num_col_dims) {
auto in_mat_dims = framework::flatten_to_2d(in_dims, in_num_col_dims);
PADDLE_ENFORCE_EQ(
in_mat_dims[1], w_dims[0],
"Fully Connected input and weigth size do not match. %s, %s");
out_dims.reserve(static_cast<size_t>(in_num_col_dims + 1));
for (int i = 0; i < in_num_col_dims; ++i) {
out_dims.push_back(in_dims[i]);
}
out_dims.push_back(w_dims[1]);
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -88,7 +88,8 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -88,7 +88,8 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) " "(boolean, default false) "
"Sparse update.") "Sparse update.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "") AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
"Skip calling InferShape() function in the runtime.")
.SetDefault(true); .SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
FusedEmbeddingSeqPool Operator. FusedEmbeddingSeqPool Operator.
......
...@@ -54,7 +54,8 @@ $$Out = scale * X$$ ...@@ -54,7 +54,8 @@ $$Out = scale * X$$
)DOC"); )DOC");
AddAttr<int>("num_hash", "").SetDefault(1); AddAttr<int>("num_hash", "").SetDefault(1);
AddAttr<int>("mod_by", "").SetDefault(100000); AddAttr<int>("mod_by", "").SetDefault(100000);
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "") AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
"Skip calling InferShape() function in the runtime.")
.SetDefault(true); .SetDefault(true);
} }
}; };
......
...@@ -123,7 +123,7 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -123,7 +123,7 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
auto input = ctx.Input<Tensor>("Input"); auto input = ctx.Input<framework::LoDTensor>("Input");
auto w = ctx.Input<Tensor>("W"); auto w = ctx.Input<Tensor>("W");
auto bias = ctx.Input<Tensor>("Bias"); auto bias = ctx.Input<Tensor>("Bias");
...@@ -151,7 +151,13 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -151,7 +151,13 @@ class FCMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
const T* w_data = w->data<T>(); const T* w_data = w->data<T>();
auto output = ctx.Output<Tensor>("Out"); auto output = ctx.Output<framework::LoDTensor>("Out");
int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
std::vector<int64_t> output_dims;
FCOutputSize(input->dims(), w->dims(), output_dims, in_num_col_dims);
output->Resize(framework::make_ddim(output_dims));
output->set_lod(input->lod());
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto dst_memory = mem.dst(output_data); auto dst_memory = mem.dst(output_data);
...@@ -204,19 +210,21 @@ class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -204,19 +210,21 @@ class FCMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input")); Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
Tensor* w_grad = ctx.Output<Tensor>(framework::GradVarName("W")); Tensor* w_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
const Tensor* input = ctx.Input<Tensor>("Input");
const T* input_data = input->data<T>();
const Tensor* w = ctx.Input<Tensor>("W");
const T* w_data = w->data<T>();
if (input_grad) { if (input_grad) {
input_grad->Resize(input->dims());
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
} }
if (w_grad) { if (w_grad) {
w_grad->Resize(w->dims());
w_grad_data = w_grad->mutable_data<T>(ctx.GetPlace()); w_grad_data = w_grad->mutable_data<T>(ctx.GetPlace());
} }
const Tensor* input = ctx.Input<Tensor>("Input");
const T* input_data = input->data<T>();
const Tensor* w = ctx.Input<Tensor>("W");
const T* w_data = w->data<T>();
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
const T* out_grad_data = out_grad->data<T>(); const T* out_grad_data = out_grad->data<T>();
......
...@@ -59,7 +59,8 @@ class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -59,7 +59,8 @@ class SequenceEnumerateOpMaker : public framework::OpProtoAndCheckerMaker {
}); });
AddAttr<int>("pad_value", "(int) The enumerate sequence padding value.") AddAttr<int>("pad_value", "(int) The enumerate sequence padding value.")
.SetDefault(0); .SetDefault(0);
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "") AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
"Skip calling InferShape() function in the runtime.")
.SetDefault(true); .SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
Sequence Enumerate Operator. Sequence Enumerate Operator.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册