未验证 提交 3dc99088 编写于 作者: H huangjiyi 提交者: GitHub

move fusion_group infershape to phi (#53934)

* update

* update

* update

* set out dtype
上级 bcf67536
...@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -21,60 +23,6 @@ class FusionGroupOp : public framework::OperatorWithKernel { ...@@ -21,60 +23,6 @@ class FusionGroupOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("Inputs"), "Input", "Inputs", "FusionGroup");
OP_INOUT_CHECK(ctx->HasOutputs("Outs"), "Output", "Outs", "FusionGroup");
auto input_names = ctx->Inputs("Inputs");
auto output_names = ctx->Outputs("Outs");
const size_t num_ins = input_names.size();
const size_t num_outs = output_names.size();
PADDLE_ENFORCE_GE(
num_ins,
1UL,
platform::errors::InvalidArgument(
"Expected the number of inputs >= 1. Received %d.", num_ins));
PADDLE_ENFORCE_GE(
num_outs,
1UL,
platform::errors::InvalidArgument(
"Expected the number of outputs >= 1. Recived %d.", num_outs));
int type = ctx->Attrs().Get<int>("type");
PADDLE_ENFORCE_EQ(type,
0UL,
platform::errors::InvalidArgument(
"Only support fusion of elementwise operations."));
std::vector<framework::DDim> x_dims = ctx->GetInputsDim("Inputs");
if (type == 0) {
for (size_t i = 1; i < num_ins; ++i) {
PADDLE_ENFORCE_EQ(
x_dims[0],
x_dims[i],
platform::errors::InvalidArgument(
"All the inputs' dims is expected to be the same. "
"But received [%s] (name: %s) vs [%s] (name: %s).",
x_dims[0],
input_names[0],
x_dims[i],
input_names[i]));
}
std::vector<framework::DDim> out_dims;
for (size_t j = 0; j < num_outs; ++j) {
out_dims.push_back(x_dims[0]);
}
ctx->SetOutputsDim("Outs", out_dims);
}
// Only lod of Inputs[0] would be shared with Outs.
for (size_t j = 0; j < num_outs; ++j) {
ctx->ShareLoD("Inputs", /*->*/ "Outs", 0, j);
}
}
protected: protected:
phi::KernelKey GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -115,5 +63,12 @@ multiple operators into one. It supports several types: ...@@ -115,5 +63,12 @@ multiple operators into one. It supports several types:
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(fusion_group,
FusionGroupInferShapeFunctor,
PD_INFER_META(phi::FusionGroupInferMeta));
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_group, ops::FusionGroupOp, ops::FusionGroupOpMaker); REGISTER_OPERATOR(fusion_group,
ops::FusionGroupOp,
ops::FusionGroupOpMaker,
FusionGroupInferShapeFunctor);
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h" #include "paddle/phi/kernels/funcs/concat_funcs.h"
...@@ -1313,6 +1314,71 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x, ...@@ -1313,6 +1314,71 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
} }
} }
void FusionGroupInferMeta(const std::vector<const MetaTensor*>& ins,
const std::vector<int>& outs_dtype,
const std::vector<int>& inputs_dtype,
const std::string& func_name,
int type,
std::vector<MetaTensor*> outs) {
const size_t num_ins = ins.size();
const size_t num_outs = outs.size();
PADDLE_ENFORCE_GE(
num_ins,
1UL,
phi::errors::InvalidArgument(
"Expected the number of inputs >= 1. Received %d.", num_ins));
PADDLE_ENFORCE_GE(
num_outs,
1UL,
phi::errors::InvalidArgument(
"Expected the number of outputs >= 1. Recived %d.", num_outs));
PADDLE_ENFORCE_EQ(type,
0UL,
phi::errors::InvalidArgument(
"Only support fusion of elementwise operations."));
std::vector<phi::DDim> x_dims;
for (size_t i = 0; i < num_ins; ++i) {
x_dims.push_back(ins[i]->dims());
}
if (type == 0) {
for (size_t i = 1; i < num_ins; ++i) {
PADDLE_ENFORCE_EQ(x_dims[0],
x_dims[i],
phi::errors::InvalidArgument(
"All the inputs' dims is expected to be the same. "
"But received [%s] (name: %s) vs [%s] (name: %s).",
x_dims[0],
ins[0],
x_dims[i],
ins[i]));
}
for (size_t j = 0; j < num_outs; ++j) {
outs[j]->set_dims(x_dims[0]);
}
}
// Only lod of Inputs[0] would be shared with Outs.
for (size_t j = 0; j < num_outs; ++j) {
outs[j]->share_lod(*ins[0]);
}
for (size_t j = 0; j < num_outs; ++j) {
if (outs_dtype[j] == phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
outs[j]->set_dtype(phi::DataType::FLOAT16);
} else if (outs_dtype[j] ==
phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
outs[j]->set_dtype(phi::DataType::FLOAT32);
} else if (outs_dtype[j] ==
phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
outs[j]->set_dtype(phi::DataType::FLOAT64);
}
}
}
void GenerateProposalsV2InferMeta(const MetaTensor& scores, void GenerateProposalsV2InferMeta(const MetaTensor& scores,
const MetaTensor& bbox_deltas, const MetaTensor& bbox_deltas,
const MetaTensor& im_shape, const MetaTensor& im_shape,
......
...@@ -268,6 +268,13 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x, ...@@ -268,6 +268,13 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
MetaTensor* dweight_out, MetaTensor* dweight_out,
MetaTensor* dbias_out); MetaTensor* dbias_out);
void FusionGroupInferMeta(const std::vector<const MetaTensor*>& ins,
const std::vector<int>& outs_dtype,
const std::vector<int>& inputs_dtype,
const std::string& func_name,
int type,
std::vector<MetaTensor*> outs);
void GenerateProposalsV2InferMeta(const MetaTensor& scores, void GenerateProposalsV2InferMeta(const MetaTensor& scores,
const MetaTensor& bbox_deltas, const MetaTensor& bbox_deltas,
const MetaTensor& im_shape, const MetaTensor& im_shape,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册