From c05af910bc1cfac95910abc34bfc742311225303 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Tue, 19 Mar 2019 14:36:52 +0800 Subject: [PATCH] refine cos_sim infershape test=develop --- paddle/fluid/operators/cos_sim_op.cc | 3 +++ paddle/fluid/operators/cos_sim_op.h | 18 +++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/cos_sim_op.cc b/paddle/fluid/operators/cos_sim_op.cc index 8f3644039f9..30ec74d8442 100644 --- a/paddle/fluid/operators/cos_sim_op.cc +++ b/paddle/fluid/operators/cos_sim_op.cc @@ -74,6 +74,9 @@ class CosSimOpMaker : public framework::OpProtoAndCheckerMaker { "Norm of the second input, reduced along the 1st " "dimension.") .AsIntermediate(); + AddAttr(framework::kAllKernelsMustComputeRuntimeShape, + "Skip calling InferShape() function in the runtime.") + .SetDefault(true); AddComment(R"DOC( **Cosine Similarity Operator** diff --git a/paddle/fluid/operators/cos_sim_op.h b/paddle/fluid/operators/cos_sim_op.h index 76cfc680518..0b4e3f77467 100644 --- a/paddle/fluid/operators/cos_sim_op.h +++ b/paddle/fluid/operators/cos_sim_op.h @@ -28,17 +28,21 @@ class CosSimKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { // get Tensor - auto* in_x = context.Input("X"); + auto* in_x = context.Input("X"); auto* in_y = context.Input("Y"); - auto* out_z = context.Output("Out"); + auto* out_z = context.Output("Out"); auto* out_x_norm = context.Output("XNorm"); auto* out_y_norm = context.Output("YNorm"); - out_z->mutable_data(context.GetPlace()); - out_x_norm->mutable_data(context.GetPlace()); - out_y_norm->mutable_data(context.GetPlace()); int rows_x = in_x->dims()[0]; int rows_y = in_y->dims()[0]; + out_z->Resize({rows_x, 1}); + out_x_norm->Resize({rows_x, 1}); + out_y_norm->Resize({rows_y, 1}); + out_z->mutable_data(context.GetPlace()); + out_x_norm->mutable_data(context.GetPlace()); + out_y_norm->mutable_data(context.GetPlace()); + out_z->set_lod(in_x->lod()); int cols = framework::product(in_x->dims()) / rows_x; @@ -81,6 +85,7 @@ class CosSimGradKernel : public framework::OpKernel { if (rows_x == rows_y) { if (out_grad_x) { + out_grad_x->Resize(in_x->dims()); math::CosSimGradFunctor functor( in_x_norm->data(), in_y_norm->data(), in_x->data(), in_y->data(), in_z->data(), in_grad_z->data(), @@ -91,6 +96,7 @@ class CosSimGradKernel : public framework::OpKernel { for_range(functor); } if (out_grad_y) { + out_grad_y->Resize(in_y->dims()); math::CosSimGradFunctor functor( in_y_norm->data(), in_x_norm->data(), in_y->data(), in_x->data(), in_z->data(), in_grad_z->data(), @@ -102,6 +108,7 @@ class CosSimGradKernel : public framework::OpKernel { } } else { if (out_grad_x) { + out_grad_x->Resize(in_x->dims()); math::CosSimDxFunctor functor( in_x_norm->data(), in_y_norm->data(), in_x->data(), in_y->data(), in_z->data(), in_grad_z->data(), @@ -112,6 +119,7 @@ class CosSimGradKernel : public framework::OpKernel { for_range(functor); } if (out_grad_y) { + out_grad_y->Resize(in_y->dims()); out_grad_y->mutable_data(context.GetPlace()); math::SetConstant set_zero; auto& dev_ctx = context.template device_context(); -- GitLab