未验证 提交 3dc99088 编写于 作者: H huangjiyi 提交者: GitHub

move fusion_group infershape to phi (#53934)

* update

* update

* update

* set out dtype
上级 bcf67536
......@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
......@@ -21,60 +23,6 @@ class FusionGroupOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
// Validates fusion_group's inputs/outputs and infers output shapes.
// Currently only elementwise fusion (type == 0) is supported: all inputs
// must have identical dims, and every output inherits the dims (and the
// LoD of Inputs[0]) from the inputs.
void InferShape(framework::InferShapeContext* ctx) const override {
  OP_INOUT_CHECK(ctx->HasInputs("Inputs"), "Input", "Inputs", "FusionGroup");
  OP_INOUT_CHECK(ctx->HasOutputs("Outs"), "Output", "Outs", "FusionGroup");

  auto input_names = ctx->Inputs("Inputs");
  auto output_names = ctx->Outputs("Outs");
  const size_t num_ins = input_names.size();
  const size_t num_outs = output_names.size();

  PADDLE_ENFORCE_GE(
      num_ins,
      1UL,
      platform::errors::InvalidArgument(
          "Expected the number of inputs >= 1. Received %d.", num_ins));
  PADDLE_ENFORCE_GE(
      num_outs,
      1UL,
      platform::errors::InvalidArgument(
          "Expected the number of outputs >= 1. Received %d.", num_outs));

  int type = ctx->Attrs().Get<int>("type");
  PADDLE_ENFORCE_EQ(type,
                    0UL,
                    platform::errors::InvalidArgument(
                        "Only support fusion of elementwise operations."));

  std::vector<framework::DDim> x_dims = ctx->GetInputsDim("Inputs");
  if (type == 0) {
    // Elementwise fusion: every input must match Inputs[0]'s dims.
    for (size_t i = 1; i < num_ins; ++i) {
      PADDLE_ENFORCE_EQ(
          x_dims[0],
          x_dims[i],
          platform::errors::InvalidArgument(
              "All the inputs' dims is expected to be the same. "
              "But received [%s] (name: %s) vs [%s] (name: %s).",
              x_dims[0],
              input_names[0],
              x_dims[i],
              input_names[i]));
    }
    // Every output takes the common input shape.
    std::vector<framework::DDim> out_dims(num_outs, x_dims[0]);
    ctx->SetOutputsDim("Outs", out_dims);
  }
  // Only lod of Inputs[0] would be shared with Outs.
  for (size_t j = 0; j < num_outs; ++j) {
    ctx->ShareLoD("Inputs", /*->*/ "Outs", 0, j);
  }
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -115,5 +63,12 @@ multiple operators into one. It supports several types:
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(fusion_group,
FusionGroupInferShapeFunctor,
PD_INFER_META(phi::FusionGroupInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_group, ops::FusionGroupOp, ops::FusionGroupOpMaker);
REGISTER_OPERATOR(fusion_group,
ops::FusionGroupOp,
ops::FusionGroupOpMaker,
FusionGroupInferShapeFunctor);
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"
......@@ -1313,6 +1314,71 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
}
}
// InferMeta for the fusion_group op.
// Only elementwise fusion (type == 0) is supported: all inputs must share
// the same dims, every output inherits the dims and the LoD of ins[0], and
// each output's dtype is set from outs_dtype (proto var-type ints).
// NOTE(review): inputs_dtype and func_name are part of the op attrs but are
// unused during shape inference.
void FusionGroupInferMeta(const std::vector<const MetaTensor*>& ins,
                          const std::vector<int>& outs_dtype,
                          const std::vector<int>& inputs_dtype,
                          const std::string& func_name,
                          int type,
                          std::vector<MetaTensor*> outs) {
  const size_t num_ins = ins.size();
  const size_t num_outs = outs.size();

  PADDLE_ENFORCE_GE(
      num_ins,
      1UL,
      phi::errors::InvalidArgument(
          "Expected the number of inputs >= 1. Received %d.", num_ins));
  PADDLE_ENFORCE_GE(
      num_outs,
      1UL,
      phi::errors::InvalidArgument(
          "Expected the number of outputs >= 1. Received %d.", num_outs));
  PADDLE_ENFORCE_EQ(type,
                    0UL,
                    phi::errors::InvalidArgument(
                        "Only support fusion of elementwise operations."));

  std::vector<phi::DDim> x_dims;
  x_dims.reserve(num_ins);
  for (size_t i = 0; i < num_ins; ++i) {
    x_dims.push_back(ins[i]->dims());
  }

  if (type == 0) {
    for (size_t i = 1; i < num_ins; ++i) {
      // BUGFIX: the old message passed ins[0]/ins[i] (MetaTensor pointers)
      // into "%s" name slots, printing raw pointers. Report indices instead.
      PADDLE_ENFORCE_EQ(
          x_dims[0],
          x_dims[i],
          phi::errors::InvalidArgument(
              "All the inputs' dims are expected to be the same. "
              "But received [%s] (input 0) vs [%s] (input %d).",
              x_dims[0],
              x_dims[i],
              i));
    }
    for (size_t j = 0; j < num_outs; ++j) {
      outs[j]->set_dims(x_dims[0]);
    }
  }

  // Only lod of Inputs[0] would be shared with Outs.
  for (size_t j = 0; j < num_outs; ++j) {
    outs[j]->share_lod(*ins[0]);
  }

  // Map the proto var-type ints in outs_dtype back to phi dtypes.
  // NOTE(review): only float16/32/64 are handled; any other value leaves the
  // output dtype unset — presumably fusion_group only emits float outputs,
  // but confirm against the fusion_group kernel.
  for (size_t j = 0; j < num_outs; ++j) {
    if (outs_dtype[j] == phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
      outs[j]->set_dtype(phi::DataType::FLOAT16);
    } else if (outs_dtype[j] ==
               phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
      outs[j]->set_dtype(phi::DataType::FLOAT32);
    } else if (outs_dtype[j] ==
               phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
      outs[j]->set_dtype(phi::DataType::FLOAT64);
    }
  }
}
void GenerateProposalsV2InferMeta(const MetaTensor& scores,
const MetaTensor& bbox_deltas,
const MetaTensor& im_shape,
......
......@@ -268,6 +268,13 @@ void FusedLinearParamGradAddInferMeta(const MetaTensor& x,
MetaTensor* dweight_out,
MetaTensor* dbias_out);
// InferMeta for fusion_group: checks that all inputs share the same dims
// (elementwise fusion, type == 0), sets each output's dims/LoD from ins[0],
// and sets output dtypes from the proto var-type ints in outs_dtype.
void FusionGroupInferMeta(const std::vector<const MetaTensor*>& ins,
                          const std::vector<int>& outs_dtype,
                          const std::vector<int>& inputs_dtype,
                          const std::string& func_name,
                          int type,
                          std::vector<MetaTensor*> outs);
void GenerateProposalsV2InferMeta(const MetaTensor& scores,
const MetaTensor& bbox_deltas,
const MetaTensor& im_shape,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册