Unverified commit a862debf authored by Wang Xin, committed by GitHub

move sequence_mask op InferShape func (#53782)

* move sequence_mask op InferShape func

* add dtype infer
Parent 2782b291
@@ -12,7 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/binary.h"
 namespace paddle {
 namespace operators {
@@ -21,21 +24,6 @@ class SequenceMaskOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequenceMask");
-    OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "SequenceMask");
-    int maxlen = ctx->Attrs().Get<int>("maxlen");
-    auto dim = phi::vectorize<int>(ctx->GetInputDim("X"));
-    if (ctx->HasInputs("MaxLenTensor")) {
-      dim.push_back(-1);
-    } else {
-      dim.push_back(maxlen > 0 ? maxlen : -1);
-    }
-    ctx->SetOutputDim("Y", phi::make_ddim(dim));
-  }
  protected:
   phi::KernelKey GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -93,9 +81,14 @@ If maxlen < 0, maxlen = max(X)
 }  // namespace operators
 }  // namespace paddle
+DECLARE_INFER_SHAPE_FUNCTOR(sequence_mask,
+                            SequenceMaskInferShapeFunctor,
+                            PD_INFER_META(phi::SequenceMaskInferMeta));
 REGISTER_OPERATOR(
     sequence_mask,
     paddle::operators::SequenceMaskOp,
     paddle::operators::SequenceMaskOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    SequenceMaskInferShapeFunctor);
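The registration change above is the heart of the move: the operator no longer overrides InferShape; instead, DECLARE_INFER_SHAPE_FUNCTOR generates a functor from the phi InferMeta function, and that functor is passed to REGISTER_OPERATOR. A rough, self-contained sketch of that adapter pattern follows, using hypothetical Fake* names rather than Paddle's actual macro expansion.

// Simplified illustration of the functor-adapter idea behind
// DECLARE_INFER_SHAPE_FUNCTOR. All types here are hypothetical stand-ins,
// not Paddle's real internals.
#include <vector>

// Stand-in for framework::InferShapeContext.
struct FakeInferShapeContext {
  std::vector<int> input_dims;
  std::vector<int> output_dims;
};

// Stand-in for a phi-style free InferMeta function (real shape logic elided).
void FakeInferMeta(const std::vector<int>& x_dims, std::vector<int>* y_dims) {
  *y_dims = x_dims;
}

// The generated functor: registered with the operator in place of a
// hand-written InferShape override, it only forwards the context's data
// to the free InferMeta function.
struct FakeInferShapeFunctor {
  void operator()(FakeInferShapeContext* ctx) const {
    FakeInferMeta(ctx->input_dims, &ctx->output_dims);
  }
};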
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/phi/common/type_traits.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/infermeta/unary.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
@@ -2584,6 +2585,24 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
   }
 }
+void SequenceMaskInferMeta(const MetaTensor& x,
+                           const MetaTensor& max_len_tensor,
+                           int maxlen,
+                           int out_dtype,
+                           MetaTensor* y) {
+  auto dim = phi::vectorize<int>(x.dims());
+  if (max_len_tensor) {
+    dim.push_back(-1);
+  } else {
+    dim.push_back(maxlen > 0 ? maxlen : -1);
+  }
+  y->set_dims(phi::make_ddim(dim));
+  auto out_phi_dtype = phi::TransToPhiDataType(out_dtype);
+  y->set_dtype(out_phi_dtype);
+}
 void SoftmaxMaskFuseInferMeta(const MetaTensor& x,
                               const MetaTensor& mask,
                               MetaTensor* out) {
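The new SequenceMaskInferMeta reproduces the shape rule of the removed fluid InferShape and, per the second commit bullet, additionally infers the output dtype from the out_dtype attribute via phi::TransToPhiDataType. A standalone sketch of just the shape rule with concrete values (plain C++ for illustration, not Paddle code):

#include <cassert>
#include <vector>

// Mirror of the dimension logic above: append one trailing dimension that is
// maxlen when maxlen > 0 and no MaxLenTensor is given, otherwise -1 (dynamic).
std::vector<int> InferSequenceMaskDims(std::vector<int> x_dims,
                                       bool has_max_len_tensor,
                                       int maxlen) {
  x_dims.push_back(has_max_len_tensor ? -1 : (maxlen > 0 ? maxlen : -1));
  return x_dims;
}

int main() {
  // X of shape [4, 8] with maxlen = 10  ->  Y of shape [4, 8, 10].
  assert((InferSequenceMaskDims({4, 8}, false, 10) ==
          std::vector<int>{4, 8, 10}));
  // With MaxLenTensor the length is only known at run time, so the dim is -1.
  assert((InferSequenceMaskDims({4, 8}, true, 10) ==
          std::vector<int>{4, 8, -1}));
  // maxlen <= 0 means "use max(X)", which is also unknown until run time.
  assert((InferSequenceMaskDims({4, 8}, false, -1) ==
          std::vector<int>{4, 8, -1}));
  return 0;
}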
@@ -399,6 +399,12 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
                            bool right,
                            MetaTensor* out);
+void SequenceMaskInferMeta(const MetaTensor& x,
+                           const MetaTensor& max_len_tensor,
+                           int maxlen,
+                           int out_dtype,
+                           MetaTensor* y);
 void SoftmaxMaskFuseInferMeta(const MetaTensor& x,
                               const MetaTensor& mask,
                               MetaTensor* out);