未验证 提交 f014e301 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] add int64_t kernels for YoloV3, test=develop (#35045)

* [NPU] add int64 kernels, test=develop

* update ci scripts to be able to trun WITH_ASCEND_INT64 on, test=develop
上级 f13dcfb1
...@@ -224,6 +224,7 @@ option(WITH_UNITY_BUILD "Compile with UnityBuild mode" OFF) ...@@ -224,6 +224,7 @@ option(WITH_UNITY_BUILD "Compile with UnityBuild mode" OFF)
option(WITH_STRIP "Strip so files of Whl packages" OFF) option(WITH_STRIP "Strip so files of Whl packages" OFF)
option(NEW_RELEASE_CUBIN "PaddlePaddle next-level release strategy for pypi cubin package" OFF) option(NEW_RELEASE_CUBIN "PaddlePaddle next-level release strategy for pypi cubin package" OFF)
option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF) option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup jit package" OFF)
option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF)
# PY_VERSION # PY_VERSION
if(NOT PY_VERSION) if(NOT PY_VERSION)
......
...@@ -90,6 +90,10 @@ if(WITH_ASCEND_CL) ...@@ -90,6 +90,10 @@ if(WITH_ASCEND_CL)
add_definitions(-DPADDLE_WITH_ASCEND_CL) add_definitions(-DPADDLE_WITH_ASCEND_CL)
endif() endif()
if(WITH_ASCEND_INT64)
add_definitions(-DPADDLE_WITH_ASCEND_INT64)
endif()
if(WITH_XPU) if(WITH_XPU)
message(STATUS "Compile with XPU!") message(STATUS "Compile with XPU!")
add_definitions(-DPADDLE_WITH_XPU) add_definitions(-DPADDLE_WITH_XPU)
......
...@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/fill_constant_op.h" #include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/operators/npu_op_runner.h" #include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/operators/utils.h"
...@@ -22,7 +19,7 @@ limitations under the License. */ ...@@ -22,7 +19,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class FillConstantNPUKernel : public framework::OpKernel<T> { class FillConstantNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -32,7 +29,6 @@ class FillConstantNPUKernel : public framework::OpKernel<T> { ...@@ -32,7 +29,6 @@ class FillConstantNPUKernel : public framework::OpKernel<T> {
auto float_value = ctx.Attr<float>("value"); auto float_value = ctx.Attr<float>("value");
auto* out_var = ctx.Output<framework::Tensor>("Out"); auto* out_var = ctx.Output<framework::Tensor>("Out");
auto place = ctx.GetPlace();
auto stream = auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>() ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream(); .stream();
...@@ -63,25 +59,28 @@ class FillConstantNPUKernel : public framework::OpKernel<T> { ...@@ -63,25 +59,28 @@ class FillConstantNPUKernel : public framework::OpKernel<T> {
} }
auto shape = GetShape(ctx); auto shape = GetShape(ctx);
Tensor tensor_tmp(data_type); Tensor tensor_value(data_type);
tensor_tmp.mutable_data<T>({1}, ctx.GetPlace()); tensor_value.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&tensor_tmp, value); FillNpuTensorWithConstant<T>(&tensor_value, value);
out_var->mutable_data<T>(shape, ctx.GetPlace());
out_var->mutable_data<T>(shape, place); NpuOpRunner runner;
const auto& runner = NpuOpRunner("FillD", {tensor_tmp}, {*out_var}, runner.SetType("Fill")
{{"dims", framework::vectorize(shape)}}); .AddInput(framework::vectorize(shape))
runner.Run(stream); .AddInput(tensor_value)
.AddOutput(*out_var)
.Run(stream);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL( REGISTER_OP_NPU_KERNEL(
fill_constant, fill_constant, paddle::operators::FillConstantNPUKernel<float>,
ops::FillConstantNPUKernel<paddle::platform::NPUDeviceContext, float>, paddle::operators::FillConstantNPUKernel<bool>,
ops::FillConstantNPUKernel<paddle::platform::NPUDeviceContext, bool>, paddle::operators::FillConstantNPUKernel<int>,
ops::FillConstantNPUKernel<paddle::platform::NPUDeviceContext, int>, #ifdef PADDLE_WITH_ASCEND_INT64
ops::FillConstantNPUKernel<paddle::platform::NPUDeviceContext, paddle::operators::FillConstantNPUKernel<int64_t>,
paddle::platform::float16>); #endif
paddle::operators::FillConstantNPUKernel<paddle::platform::float16>);
...@@ -14,22 +14,11 @@ ...@@ -14,22 +14,11 @@
#include "paddle/fluid/operators/increment_op.h" #include "paddle/fluid/operators/increment_op.h"
#include "paddle/fluid/operators/npu_op_runner.h" #include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace framework {
class OpDesc;
class Variable;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
} // namespace paddle
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class IncrementalNPUKernel : public framework::OpKernel<T> { class IncrementalNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -56,13 +45,11 @@ class IncrementalNPUKernel : public framework::OpKernel<T> { ...@@ -56,13 +45,11 @@ class IncrementalNPUKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL( REGISTER_OP_NPU_KERNEL(
increment, increment, paddle::operators::IncrementalNPUKernel<float>,
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, float>, paddle::operators::IncrementalNPUKernel<double>,
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, double>, paddle::operators::IncrementalNPUKernel<int>,
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, int>, #ifdef PADDLE_WITH_ASCEND_INT64
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, paddle::operators::IncrementalNPUKernel<int64_t>,
plat::float16>) #endif
paddle::operators::IncrementalNPUKernel<paddle::platform::float16>)
...@@ -18,7 +18,7 @@ limitations under the Licnse. */ ...@@ -18,7 +18,7 @@ limitations under the Licnse. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class MeshgridNPUKernel : public framework::OpKernel<T> { class MeshgridNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -65,9 +65,12 @@ class MeshgridNPUKernel : public framework::OpKernel<T> { ...@@ -65,9 +65,12 @@ class MeshgridNPUKernel : public framework::OpKernel<T> {
auto stream = auto stream =
context.template device_context<paddle::platform::NPUDeviceContext>() context.template device_context<paddle::platform::NPUDeviceContext>()
.stream(); .stream();
const auto& runner = NpuOpRunner("BroadcastToD", {reshape_ins_tensor}, NpuOpRunner runner;
{*(outs[i])}, {{"shape", shape}}); runner.SetType("BroadcastTo")
runner.Run(stream); .AddInput(reshape_ins_tensor)
.AddInput(std::move(shape))
.AddOutput(*(outs[i]))
.Run(stream);
} }
} }
}; };
...@@ -75,10 +78,10 @@ class MeshgridNPUKernel : public framework::OpKernel<T> { ...@@ -75,10 +78,10 @@ class MeshgridNPUKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL( REGISTER_OP_NPU_KERNEL(
meshgrid, ops::MeshgridNPUKernel<plat::NPUDeviceContext, float>, meshgrid, paddle::operators::MeshgridNPUKernel<int>,
ops::MeshgridNPUKernel<plat::NPUDeviceContext, plat::float16>, #ifdef PADDLE_WITH_ASCEND_INT64
ops::MeshgridNPUKernel<plat::NPUDeviceContext, int32_t>); paddle::operators::MeshgridNPUKernel<int64_t>,
#endif
paddle::operators::MeshgridNPUKernel<float>,
paddle::operators::MeshgridNPUKernel<paddle::platform::float16>);
...@@ -12,24 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,24 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/range_op.h" #include "paddle/fluid/operators/range_op.h"
#include "paddle/fluid/operators/utils.h" #include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class RangeNPUKernel : public framework::OpKernel<T> { class RangeNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -41,19 +30,19 @@ class RangeNPUKernel : public framework::OpKernel<T> { ...@@ -41,19 +30,19 @@ class RangeNPUKernel : public framework::OpKernel<T> {
framework::Tensor n; framework::Tensor n;
framework::TensorCopy( framework::TensorCopy(
*start_t, platform::CPUPlace(), *start_t, platform::CPUPlace(),
context.template device_context<platform::DeviceContext>(), &n); context.template device_context<platform::NPUDeviceContext>(), &n);
context.template device_context<paddle::platform::NPUDeviceContext>() context.template device_context<paddle::platform::NPUDeviceContext>()
.Wait(); .Wait();
T start = n.data<T>()[0]; T start = n.data<T>()[0];
framework::TensorCopy( framework::TensorCopy(
*end_t, platform::CPUPlace(), *end_t, platform::CPUPlace(),
context.template device_context<platform::DeviceContext>(), &n); context.template device_context<platform::NPUDeviceContext>(), &n);
context.template device_context<paddle::platform::NPUDeviceContext>() context.template device_context<paddle::platform::NPUDeviceContext>()
.Wait(); .Wait();
T end = n.data<T>()[0]; T end = n.data<T>()[0];
framework::TensorCopy( framework::TensorCopy(
*step_t, platform::CPUPlace(), *step_t, platform::CPUPlace(),
context.template device_context<platform::DeviceContext>(), &n); context.template device_context<platform::NPUDeviceContext>(), &n);
context.template device_context<paddle::platform::NPUDeviceContext>() context.template device_context<paddle::platform::NPUDeviceContext>()
.Wait(); .Wait();
T step = n.data<T>()[0]; T step = n.data<T>()[0];
...@@ -78,11 +67,9 @@ class RangeNPUKernel : public framework::OpKernel<T> { ...@@ -78,11 +67,9 @@ class RangeNPUKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; REGISTER_OP_NPU_KERNEL(range, paddle::operators::RangeNPUKernel<int>,
#ifdef PADDLE_WITH_ASCEND_INT64
REGISTER_OP_NPU_KERNEL( paddle::operators::RangeNPUKernel<int64_t>,
range, ops::RangeNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::RangeNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::RangeNPUKernel<paddle::platform::NPUDeviceContext, double>)
#endif #endif
paddle::operators::RangeNPUKernel<float>,
paddle::operators::RangeNPUKernel<double>)
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class ScaleNPUKernel : public framework::OpKernel<T> { class ScaleNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -40,37 +40,21 @@ class ScaleNPUKernel : public framework::OpKernel<T> { ...@@ -40,37 +40,21 @@ class ScaleNPUKernel : public framework::OpKernel<T> {
auto* scale_tensor = ctx.Input<framework::Tensor>("ScaleTensor"); auto* scale_tensor = ctx.Input<framework::Tensor>("ScaleTensor");
scale = static_cast<float>(GetAttrFromTensor<T>(scale_tensor)); scale = static_cast<float>(GetAttrFromTensor<T>(scale_tensor));
} }
if (bias_after_scale) {
out->mutable_data<T>(ctx.GetPlace());
const auto& runner =
NpuOpRunner("Power", {*x}, {*out},
{{"power", power}, {"scale", scale}, {"shift", bias}});
runner.Run(stream); if (!bias_after_scale) {
} else { bias *= scale;
Tensor tmp_x(x->type());
tmp_x.Resize(x->dims());
tmp_x.mutable_data<T>(ctx.GetPlace());
const auto& runner_tmp =
NpuOpRunner("Adds", {*x}, {tmp_x}, {{"value", bias}});
runner_tmp.Run(stream);
out->mutable_data<T>(ctx.GetPlace());
float _bias = 0.0;
const auto& runner =
NpuOpRunner("Power", {tmp_x}, {*out},
{{"power", power}, {"scale", scale}, {"shift", _bias}});
runner.Run(stream);
} }
out->mutable_data<T>(ctx.GetPlace());
const auto& runner =
NpuOpRunner("Power", {*x}, {*out},
{{"power", power}, {"scale", scale}, {"shift", bias}});
runner.Run(stream);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL( REGISTER_OP_NPU_KERNEL(
scale, ops::ScaleNPUKernel<paddle::platform::NPUDeviceContext, float>, scale, paddle::operators::ScaleNPUKernel<float>,
ops::ScaleNPUKernel<paddle::platform::NPUDeviceContext, paddle::operators::ScaleNPUKernel<paddle::platform::float16>);
paddle::platform::float16>);
...@@ -20,7 +20,7 @@ namespace operators { ...@@ -20,7 +20,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename DeviceContext, typename T> template <typename T>
class StackNPUKernel : public framework::OpKernel<T> { class StackNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -49,7 +49,7 @@ class StackNPUKernel : public framework::OpKernel<T> { ...@@ -49,7 +49,7 @@ class StackNPUKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T>
class StackGradNPUKernel : public framework::OpKernel<T> { class StackGradNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -81,15 +81,18 @@ class StackGradNPUKernel : public framework::OpKernel<T> { ...@@ -81,15 +81,18 @@ class StackGradNPUKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL( REGISTER_OP_NPU_KERNEL(
stack, ops::StackNPUKernel<paddle::platform::NPUDeviceContext, float>, stack, paddle::operators::StackNPUKernel<int>,
ops::StackNPUKernel<paddle::platform::NPUDeviceContext, #ifdef PADDLE_WITH_ASCEND_INT64
paddle::platform::float16>); paddle::operators::StackNPUKernel<int64_t>,
#endif
paddle::operators::StackNPUKernel<float>,
paddle::operators::StackNPUKernel<paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL( REGISTER_OP_NPU_KERNEL(
stack_grad, stack_grad, paddle::operators::StackNPUKernel<int>,
ops::StackGradNPUKernel<paddle::platform::NPUDeviceContext, float>, #ifdef PADDLE_WITH_ASCEND_INT64
ops::StackGradNPUKernel<paddle::platform::NPUDeviceContext, paddle::operators::StackNPUKernel<int64_t>,
paddle::platform::float16>); #endif
paddle::operators::StackGradNPUKernel<float>,
paddle::operators::StackGradNPUKernel<paddle::platform::float16>);
...@@ -228,6 +228,7 @@ function cmake_base() { ...@@ -228,6 +228,7 @@ function cmake_base() {
-DWITH_ARM=${WITH_ARM:-OFF} -DWITH_ARM=${WITH_ARM:-OFF}
-DWITH_ASCEND=${WITH_ASCEND:-OFF} -DWITH_ASCEND=${WITH_ASCEND:-OFF}
-DWITH_ASCEND_CL=${WITH_ASCEND_CL:-OFF} -DWITH_ASCEND_CL=${WITH_ASCEND_CL:-OFF}
-DWITH_ASCEND_INT64=${WITH_ASCEND_INT64:-OFF}
-DWITH_STRIP=${WITH_STRIP:-ON} -DWITH_STRIP=${WITH_STRIP:-ON}
-DON_INFER=${ON_INFER:-OFF} -DON_INFER=${ON_INFER:-OFF}
======================================== ========================================
...@@ -269,6 +270,7 @@ EOF ...@@ -269,6 +270,7 @@ EOF
-DWITH_ARM=${WITH_ARM:-OFF} \ -DWITH_ARM=${WITH_ARM:-OFF} \
-DWITH_ASCEND=${WITH_ASCEND:-OFF} \ -DWITH_ASCEND=${WITH_ASCEND:-OFF} \
-DWITH_ASCEND_CL=${WITH_ASCEND_CL:-OFF} \ -DWITH_ASCEND_CL=${WITH_ASCEND_CL:-OFF} \
-DWITH_ASCEND_INT64=${WITH_ASCEND_INT64:-OFF} \
-DWITH_STRIP=${WITH_STRIP:-ON} \ -DWITH_STRIP=${WITH_STRIP:-ON} \
-DON_INFER=${ON_INFER:-OFF} \ -DON_INFER=${ON_INFER:-OFF} \
-DWITH_UNITY_BUILD=${WITH_UNITY_BUILD:-OFF};build_error=$? -DWITH_UNITY_BUILD=${WITH_UNITY_BUILD:-OFF};build_error=$?
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -72,6 +72,30 @@ class TestFillConstantInt(OpTest): ...@@ -72,6 +72,30 @@ class TestFillConstantInt(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
class TestFillConstantInt64(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(0)
self.op_type = "fill_constant"
self.inputs = {}
self.attrs = {
'shape': [123, 92],
'value': 1,
'dtype': core.VarDesc.VarType.INT64
}
self.outputs = {'Out': np.full((123, 92), 1).astype(self.dtype)}
def set_npu(self):
self.__class__.use_npu = True
def init_dtype(self):
self.dtype = np.int64
def test_check_output(self):
self.check_output_with_place(self.place)
class TestFillConstantFP16(OpTest): class TestFillConstantFP16(OpTest):
def setUp(self): def setUp(self):
self.set_npu() self.set_npu()
......
...@@ -81,6 +81,32 @@ class TestIncrementFP16(OpTest): ...@@ -81,6 +81,32 @@ class TestIncrementFP16(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
class TestIncrementINT64(OpTest):
def setUp(self):
self.set_npu()
self.place = paddle.NPUPlace(NPUPlace)
self.op_type = "increment"
self.init_dtype()
self.inputs = {
'X':
OpTest.np_dtype_to_fluid_dtype(np.array([1]).astype(self.dtype)),
}
self.pre_input_id = id(self.inputs['X'])
self.attrs = {"Step": 1}
self.outputs = {'Out': np.array([2])}
def set_npu(self):
self.__class__.use_npu = True
def init_dtype(self):
self.dtype = np.int64
def test_check_output(self):
self.check_output_with_place(self.place)
class TestIncrementInplace(unittest.TestCase): class TestIncrementInplace(unittest.TestCase):
def test_npu(self): def test_npu(self):
main_prog = paddle.static.Program() main_prog = paddle.static.Program()
......
...@@ -75,6 +75,16 @@ class TestMeshgridOpFP16(TestMeshgridOp): ...@@ -75,6 +75,16 @@ class TestMeshgridOpFP16(TestMeshgridOp):
return "float16" return "float16"
class TestMeshgridOpINT32(TestMeshgridOp):
def get_dtype(self):
return "int32"
class TestMeshgridOpINT64(TestMeshgridOp):
def get_dtype(self):
return "int64"
class TestMeshgridOp2(TestMeshgridOp): class TestMeshgridOp2(TestMeshgridOp):
def get_x_shape(self): def get_x_shape(self):
return [100, 300] return [100, 300]
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
paddle.enable_static()
class TestRangeOp(OpTest):
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def setUp(self):
self.set_npu()
self.op_type = "range"
self.init_config()
self.inputs = {
'Start': np.array([self.case[0]]).astype(self.dtype),
'End': np.array([self.case[1]]).astype(self.dtype),
'Step': np.array([self.case[2]]).astype(self.dtype)
}
self.outputs = {
'Out': np.arange(self.case[0], self.case[1],
self.case[2]).astype(self.dtype)
}
def init_config(self):
self.dtype = np.float32
self.case = (0, 1, 0.2)
def test_check_output(self):
self.check_output_with_place(self.place)
class TestFloatRangeOpCase0(TestRangeOp):
def init_config(self):
self.dtype = np.float32
self.case = (0, 5, 1)
class TestInt32RangeOpCase0(TestRangeOp):
def init_config(self):
self.dtype = np.int32
self.case = (0, 5, 2)
class TestInt32RangeOpCase1(TestRangeOp):
def init_config(self):
self.dtype = np.int32
self.case = (10, 1, -2)
class TestInt32RangeOpCase2(TestRangeOp):
def init_config(self):
self.dtype = np.int32
self.case = (-1, -10, -2)
class TestInt64RangeOpCase0(TestRangeOp):
def init_config(self):
self.dtype = np.int64
self.case = (0, 5, 2)
class TestInt64RangeOpCase1(TestRangeOp):
def init_config(self):
self.dtype = np.int64
self.case = (10, 1, -2)
class TestInt64RangeOpCase2(TestRangeOp):
def init_config(self):
self.dtype = np.int64
self.case = (-1, -10, -2)
if __name__ == "__main__":
unittest.main()
...@@ -72,6 +72,8 @@ class TestStackOpBase(OpTest): ...@@ -72,6 +72,8 @@ class TestStackOpBase(OpTest):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.int32 or self.dtype == np.int64:
return
self.check_grad_with_place(self.place, self.get_x_names(), 'Y') self.check_grad_with_place(self.place, self.get_x_names(), 'Y')
...@@ -105,6 +107,16 @@ class TestStackOp6(TestStackOpBase): ...@@ -105,6 +107,16 @@ class TestStackOp6(TestStackOpBase):
self.axis = 3 self.axis = 3
class TestStackOpINT32(TestStackOpBase):
def init_dtype(self):
self.dtype = np.int32
class TestStackOpINT64(TestStackOpBase):
def init_dtype(self):
self.dtype = np.int64
class TestStackAPIWithLoDTensorArray(unittest.TestCase): class TestStackAPIWithLoDTensorArray(unittest.TestCase):
""" """
Test stack api when the input(x) is a LoDTensorArray. Test stack api when the input(x) is a LoDTensorArray.
...@@ -180,15 +192,15 @@ class TestTensorStackAPIWithLoDTensorArray(unittest.TestCase): ...@@ -180,15 +192,15 @@ class TestTensorStackAPIWithLoDTensorArray(unittest.TestCase):
class API_test(unittest.TestCase): class API_test(unittest.TestCase):
def test_out(self): def test_out(self):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
data1 = fluid.layers.data('data1', shape=[1, 2], dtype='float64') data1 = fluid.layers.data('data1', shape=[1, 2], dtype='float32')
data2 = fluid.layers.data('data2', shape=[1, 2], dtype='float64') data2 = fluid.layers.data('data2', shape=[1, 2], dtype='float32')
data3 = fluid.layers.data('data3', shape=[1, 2], dtype='float64') data3 = fluid.layers.data('data3', shape=[1, 2], dtype='float32')
result_stack = paddle.stack([data1, data2, data3], axis=0) result_stack = paddle.stack([data1, data2, data3], axis=0)
place = paddle.NPUPlace(0) place = paddle.NPUPlace(0)
exe = fluid.Executor(place) exe = fluid.Executor(place)
input1 = np.random.random([1, 2]).astype('float64') input1 = np.random.random([1, 2]).astype('float32')
input2 = np.random.random([1, 2]).astype('float64') input2 = np.random.random([1, 2]).astype('float32')
input3 = np.random.random([1, 2]).astype('float64') input3 = np.random.random([1, 2]).astype('float32')
result, = exe.run( result, = exe.run(
feed={"data1": input1, feed={"data1": input1,
"data2": input2, "data2": input2,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册