未验证 提交 95280a36 编写于 作者: S Sing_chan 提交者: GitHub

move trunc_op's infere shape to phi (#39772)

* move trunc_op's infere shape

* modify according to risheng's comment
上级 30992ea0
...@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -21,14 +23,6 @@ namespace operators { ...@@ -21,14 +23,6 @@ namespace operators {
class TruncOp : public framework::OperatorWithKernel { class TruncOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "trunc");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "trunc");
auto input_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", input_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
}; };
class TruncOpMaker : public framework::OpProtoAndCheckerMaker { class TruncOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -75,9 +69,13 @@ class TruncGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -75,9 +69,13 @@ class TruncGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
DELCARE_INFER_SHAPE_FUNCTOR(trunc, TruncInferShapeFunctor,
PT_INFER_META(phi::UnchangedInferMeta));
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(trunc, ops::TruncOp, ops::TruncOpMaker, REGISTER_OPERATOR(trunc, ops::TruncOp, ops::TruncOpMaker,
ops::TruncGradOpMaker<paddle::framework::OpDesc>, ops::TruncGradOpMaker<paddle::framework::OpDesc>,
ops::TruncGradOpMaker<paddle::imperative::OpBase>); ops::TruncGradOpMaker<paddle::imperative::OpBase>,
TruncInferShapeFunctor);
REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp); REGISTER_OPERATOR(trunc_grad, ops::TruncGradOp);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册