未验证 提交 7f5c14bc 编写于 作者: H hong 提交者: GitHub

[NewIR] new ir support assert op (#56353)

* fix op translator reshape type

* update

* new ir support vector type place transfer

* add test case

* update

* revert code

* add test assert new ir test

* update

* update
上级 d62dea9a
......@@ -151,9 +151,21 @@
outputs: []
no_need_buffer: null
data_transform: null
kernel:
func: [assert]
param: [cond, data, summarize]
backend: null
layout: null
data_type:
ordered: false
candidates: [cond]
to_complex_flag: [false]
dispatch: {assert: null}
force_backend: null
inplace: null
backward: null
- name: print
inputs:
- typename: Tensor
......
......@@ -703,13 +703,23 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
VLOG(6) << "need trans from " << place << " to "
<< kernel_key.backend();
// build memcopy op
new_in = AddPlaceTransferOp(
new_in,
new_in_type,
place,
phi::TransToPhiPlace(kernel.InputAt(i).backend),
kernel_key,
program.get());
auto out_place = phi::TransToPhiPlace(kernel.InputAt(i).backend);
auto new_in_alloc_type =
new_in_type.dyn_cast<dialect::AllocatedDenseTensorType>();
auto out_type = dialect::AllocatedDenseTensorType::get(
ctx,
out_place,
new_in_alloc_type.dtype(),
new_in_alloc_type.dims(),
new_in_alloc_type.data_layout(),
new_in_alloc_type.lod(),
new_in_alloc_type.offset());
new_in = AddPlaceTransferOp(new_in,
out_type,
place,
out_place,
kernel_key,
program.get());
}
} else if (new_in_type.isa<ir::VectorType>()) {
// [ todo need update here, support combine data transfomer]
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/assert_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/tensor_formatter.h"
namespace phi {
template <typename T, typename Context>
void AssertKernel(const Context& ctx,
const DenseTensor& cond,
const std::vector<const DenseTensor*>& data,
int64_t summarize) {
bool cond_flag = cond.data<bool>()[0];
if (cond_flag) {
return;
}
paddle::funcs::TensorFormatter formatter;
formatter.SetSummarize(summarize);
for (size_t i = 0; i < data.size(); ++i) {
std::string name = "data_" + std::to_string(i);
formatter.Print(*(data[i]), name);
}
PADDLE_THROW(phi::errors::InvalidArgument(
"The condition of must be true, but received false"));
}
} // namespace phi
PD_REGISTER_KERNEL(assert, CPU, ALL_LAYOUT, phi::AssertKernel, bool) {
kernel->InputAt(0).SetBackend(phi::Backend::CPU);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void AssertKernel(const Context& ctx,
const DenseTensor& cond,
const std::vector<const DenseTensor*>& data,
int64_t summarize);
} // namespace phi
......@@ -15,6 +15,7 @@
import unittest
import numpy
from dygraph_to_static_util import test_and_compare_with_new_ir
import paddle
from paddle import fluid
......@@ -47,6 +48,7 @@ class TestAssertVariable(unittest.TestCase):
self._run(func, x, with_exception, True)
self._run(func, x, with_exception, False)
@test_and_compare_with_new_ir(False)
def test_non_variable(self):
self._run_dy_static(
dyfunc_assert_non_variable, x=False, with_exception=True
......@@ -55,6 +57,7 @@ class TestAssertVariable(unittest.TestCase):
dyfunc_assert_non_variable, x=True, with_exception=False
)
@test_and_compare_with_new_ir(False)
def test_bool_variable(self):
self._run_dy_static(
dyfunc_assert_variable, x=numpy.array([False]), with_exception=True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册