未验证 提交 95280a36 编写于 作者: S Sing_chan 提交者: GitHub

move trunc_op's infere shape to phi (#39772)

* move trunc_op's infere shape

* modify according to risheng's comment
上级 30992ea0
......@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -21,14 +23,6 @@ namespace operators {
class TruncOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "trunc");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "trunc");
auto input_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", input_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class TruncOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -75,9 +69,13 @@ class TruncGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace operators
} // namespace paddle
DELCARE_INFER_SHAPE_FUNCTOR(trunc, TruncInferShapeFunctor,
PT_INFER_META(phi::UnchangedInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(trunc, ops::TruncOp, ops::TruncOpMaker,
ops::TruncGradOpMaker<paddle::framework::OpDesc>,
ops::TruncGradOpMaker<paddle::imperative::OpBase>);
ops::TruncGradOpMaker<paddle::imperative::OpBase>,
TruncInferShapeFunctor);
REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册