Unverified · Commit c5285cc5 authored by From00, committed by GitHub

Add yaml for flatten_contiguous_range OP (#41345)

* Add yaml for flatten_contiguous_range OP

* update

* Fix typos
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Parent 3152f3fb
......@@ -21,8 +21,8 @@ namespace phi {
 template <typename T, typename Context>
 void FlattenGradKernel(const Context& dev_ctx,
-                       const DenseTensor& out_grad,
                        const DenseTensor& xshape,
+                       const DenseTensor& out_grad,
                        DenseTensor* x_grad) {
   auto xshape_dims = xshape.dims();
   dev_ctx.Alloc(x_grad, out_grad.dtype());
......
......@@ -20,8 +20,8 @@ namespace phi {
 template <typename T, typename Context>
 void FlattenGradKernel(const Context& dev_ctx,
-                       const DenseTensor& out_grad,
                        const DenseTensor& xshape,
+                       const DenseTensor& out_grad,
                        DenseTensor* x_grad);
 }  // namespace phi
......@@ -31,7 +31,7 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
 KernelSignature FlattenGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
-      "flatten_grad", {GradVarName("Out"), "XShape"}, {}, {GradVarName("X")});
+      "flatten_grad", {"XShape", GradVarName("Out")}, {}, {GradVarName("X")});
 }
 }  // namespace phi
......
......@@ -12,7 +12,6 @@ cc_test(test_dot_api SRCS test_dot_api.cc DEPS ${COMMON_API_TEST_DEPS})
 cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS ${COMMON_API_TEST_DEPS})
 cc_test(test_empty_api SRCS test_empty_api.cc DEPS ${COMMON_API_TEST_DEPS})
 cc_test(test_fill_api SRCS test_fill_api.cc DEPS ${COMMON_API_TEST_DEPS})
-cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS ${COMMON_API_TEST_DEPS})
 cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS ${COMMON_API_TEST_DEPS})
 cc_test(test_cast_api SRCS test_cast_api.cc DEPS ${COMMON_API_TEST_DEPS})
 cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS ${COMMON_API_TEST_DEPS})
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include <memory>

#include "paddle/phi/api/include/api.h"

#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

PD_DECLARE_KERNEL(flatten, CPU, ALL_LAYOUT);

namespace paddle {
namespace tests {

namespace framework = paddle::framework;
using DDim = phi::DDim;

// TODO(chenweihang): Remove this test after the API is used in the dygraph
TEST(API, flatten) {
  // 1. create tensor
  const auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>(
      paddle::platform::CPUPlace());
  auto dense_x = std::make_shared<phi::DenseTensor>(
      alloc.get(),
      phi::DenseTensorMeta(phi::DataType::FLOAT32,
                           phi::make_ddim({3, 2, 2, 3}),
                           phi::DataLayout::NCHW));
  auto* dense_x_data =
      dense_x->mutable_data<float>(paddle::platform::CPUPlace());

  for (int i = 0; i < dense_x->numel(); i++) {
    dense_x_data[i] = i;
  }

  paddle::experimental::Tensor x(dense_x);
  int start_axis = 1, stop_axis = 2;

  // 2. test API
  auto out = paddle::experimental::flatten(x, start_axis, stop_axis);

  // 3. check result
  std::vector<int> expect_shape = {3, 4, 3};
  ASSERT_EQ(out.dims()[0], expect_shape[0]);
  ASSERT_EQ(out.dims()[1], expect_shape[1]);
  ASSERT_EQ(out.dims()[2], expect_shape[2]);
  ASSERT_EQ(out.numel(), 36);
  ASSERT_EQ(out.is_cpu(), true);
  ASSERT_EQ(out.type(), phi::DataType::FLOAT32);
  ASSERT_EQ(out.layout(), phi::DataLayout::NCHW);
  ASSERT_EQ(out.initialized(), true);

  bool value_equal = true;
  auto dense_out = std::dynamic_pointer_cast<phi::DenseTensor>(out.impl());
  auto* dense_out_data = dense_out->data<float>();
  for (int i = 0; i < dense_x->numel(); i++) {
    if (std::abs(dense_x_data[i] - dense_out_data[i]) > 1e-6f)
      value_equal = false;
  }
  ASSERT_EQ(value_equal, true);
}

}  // namespace tests
}  // namespace paddle
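Note: the coverage of the removed C++ API test above can be approximated from the Python side once the generated API is available. A minimal sketch using the public paddle package (illustrative only, not part of this commit):

import numpy as np
import paddle

# Same setup as the removed C++ test: flatten axes 1..2 of a (3, 2, 2, 3) float32 tensor.
x_np = np.arange(36, dtype="float32").reshape(3, 2, 2, 3)
x = paddle.to_tensor(x_np)
out = paddle.flatten(x, start_axis=1, stop_axis=2)

assert list(out.shape) == [3, 4, 3]
assert out.dtype == paddle.float32
np.testing.assert_allclose(out.numpy().ravel(), x_np.ravel(), atol=1e-6)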
......@@ -23,6 +23,8 @@ from op_test import OpTest
 class TestFlattenOp(OpTest):
     def setUp(self):
+        self.python_api = paddle.flatten
+        self.python_out_sig = ["Out"]
         self.op_type = "flatten_contiguous_range"
         self.start_axis = 0
         self.stop_axis = -1
......@@ -35,10 +37,10 @@ class TestFlattenOp(OpTest):
         }

     def test_check_output(self):
-        self.check_output(no_check_set=["XShape"])
+        self.check_output(no_check_set=["XShape"], check_eager=True)

     def test_check_grad(self):
-        self.check_grad(["X"], "Out")
+        self.check_grad(["X"], "Out", check_eager=True)

     def init_test_case(self):
         self.in_shape = (3, 2, 5, 4)
......
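For reference, the output shape these unit tests expect follows directly from in_shape, start_axis, and stop_axis. The helper below is a hypothetical sketch of that rule, not part of the test file:

def flattened_shape(in_shape, start_axis, stop_axis):
    # Normalize negative axes, then collapse dims in [start_axis, stop_axis] into one.
    rank = len(in_shape)
    start = start_axis + rank if start_axis < 0 else start_axis
    stop = stop_axis + rank if stop_axis < 0 else stop_axis
    merged = 1
    for d in in_shape[start:stop + 1]:
        merged *= d
    return list(in_shape[:start]) + [merged] + list(in_shape[stop + 1:])

print(flattened_shape((3, 2, 5, 4), 0, -1))  # [120], the default start/stop axes above
print(flattened_shape((3, 2, 2, 3), 1, 2))   # [3, 4, 3]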
......@@ -676,7 +676,11 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
     if start_axis > stop_axis:
         raise ValueError("The stop_axis should be larger than stat_axis")

-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        dy_out, _ = _C_ops.final_state_flatten(x, start_axis, stop_axis)
+        return dy_out
+
+    if _in_legacy_dygraph():
         dy_out, _ = _C_ops.flatten_contiguous_range(x, 'start_axis', start_axis,
                                                     'stop_axis', stop_axis)
         return dy_out
......
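In both dygraph branches only the flattened tensor is returned to the caller; the second output (XShape) stays internal and feeds the backward pass. A rough usage sketch, assuming a paddle build that includes this change (shapes shown in comments):

import paddle

x = paddle.rand([3, 2, 2, 3])
x.stop_gradient = False
out = paddle.flatten(x, start_axis=1, stop_axis=2)  # only `out` is returned
out.sum().backward()
print(out.shape)     # [3, 4, 3]
print(x.grad.shape)  # [3, 2, 2, 3], restored to x's shape by flatten_grad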
......@@ -547,11 +547,15 @@
 - api : flatten
   args : (Tensor x, int start_axis, int stop_axis)
-  output : Tensor
+  output : Tensor(out), Tensor(xshape)
   infer_meta :
-    func : FlattenInferMeta
+    func : FlattenWithXShapeInferMeta
   kernel :
-    func : flatten
+    func : flatten_with_xshape
+    backend : x
+  inplace : (x -> out)
+  view : (x -> out)
+  backward : flatten_grad

 # flip
 - api : flip
......
......@@ -349,6 +349,19 @@
   kernel :
     func : expm1_grad

+- backward_api : flatten_grad
+  forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape)
+  args : (Tensor xshape, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : KernelWithXShapeInferMeta
+    param : [xshape]
+  kernel :
+    func : flatten_grad
+    data_type: out_grad
+    backend: out_grad
+    layout: out_grad
+
 - backward_api : floor_grad
   forward : floor(Tensor x) -> Tensor(out)
   args : (Tensor out_grad)
......
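Since flatten is a pure reshape/view, flatten_grad conceptually only restores out_grad to the input shape recorded in xshape. A rough NumPy reference of that behavior (hypothetical helper, not part of this commit):

import numpy as np

def flatten_grad_reference(out_grad, x_shape):
    # The gradient of a reshape is just a reshape back to the original shape.
    return out_grad.reshape(x_shape)

out_grad = np.ones((3, 4, 3), dtype="float32")
x_grad = flatten_grad_reference(out_grad, (3, 2, 2, 3))
print(x_grad.shape)  # (3, 2, 2, 3)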
 {
-  "phi_apis":["conj", "nll_loss"],
+  "phi_apis":["conj", "nll_loss", "flatten"],
   "phi_kernels":["equal_all"]
 }