未验证 提交 a862debf 编写于 作者: W Wang Xin 提交者: GitHub

move sequence_mask op InferShape func (#53782)

* move sequence_mask op InferShape func

* add dtype infer
上级 2782b291
...@@ -12,7 +12,10 @@ ...@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -21,21 +24,6 @@ class SequenceMaskOp : public framework::OperatorWithKernel { ...@@ -21,21 +24,6 @@ class SequenceMaskOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequenceMask");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "SequenceMask");
int maxlen = ctx->Attrs().Get<int>("maxlen");
auto dim = phi::vectorize<int>(ctx->GetInputDim("X"));
if (ctx->HasInputs("MaxLenTensor")) {
dim.push_back(-1);
} else {
dim.push_back(maxlen > 0 ? maxlen : -1);
}
ctx->SetOutputDim("Y", phi::make_ddim(dim));
}
protected: protected:
phi::KernelKey GetExpectedKernelType( phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -93,9 +81,14 @@ If maxlen < 0, maxlen = max(X) ...@@ -93,9 +81,14 @@ If maxlen < 0, maxlen = max(X)
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(sequence_mask,
SequenceMaskInferShapeFunctor,
PD_INFER_META(phi::SequenceMaskInferMeta));
REGISTER_OPERATOR( REGISTER_OPERATOR(
sequence_mask, sequence_mask,
paddle::operators::SequenceMaskOp, paddle::operators::SequenceMaskOp,
paddle::operators::SequenceMaskOpMaker, paddle::operators::SequenceMaskOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
SequenceMaskInferShapeFunctor);
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/phi/common/type_traits.h" #include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/axis_utils.h"
...@@ -2584,6 +2585,24 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence, ...@@ -2584,6 +2585,24 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
} }
} }
void SequenceMaskInferMeta(const MetaTensor& x,
const MetaTensor& max_len_tensor,
int maxlen,
int out_dtype,
MetaTensor* y) {
auto dim = phi::vectorize<int>(x.dims());
if (max_len_tensor) {
dim.push_back(-1);
} else {
dim.push_back(maxlen > 0 ? maxlen : -1);
}
y->set_dims(phi::make_ddim(dim));
auto out_phi_dtype = phi::TransToPhiDataType(out_dtype);
y->set_dtype(out_phi_dtype);
}
void SoftmaxMaskFuseInferMeta(const MetaTensor& x, void SoftmaxMaskFuseInferMeta(const MetaTensor& x,
const MetaTensor& mask, const MetaTensor& mask,
MetaTensor* out) { MetaTensor* out) {
......
...@@ -399,6 +399,12 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence, ...@@ -399,6 +399,12 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
bool right, bool right,
MetaTensor* out); MetaTensor* out);
void SequenceMaskInferMeta(const MetaTensor& x,
const MetaTensor& max_len_tensor,
int maxlen,
int out_dtype,
MetaTensor* y);
void SoftmaxMaskFuseInferMeta(const MetaTensor& x, void SoftmaxMaskFuseInferMeta(const MetaTensor& x,
const MetaTensor& mask, const MetaTensor& mask,
MetaTensor* out); MetaTensor* out);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册