未验证 提交 3d750f9c 编写于 作者: W Wu Yi 提交者: GitHub

[Feature] Fp16 training for resnet50 (#14850)

* wip

* wip

* wip

* wip for test

* add fp16 tests test=develop

* fix cpu build test=develop

* fix test=develop

* fix py3 tests test=develop

* fix lr_scheduler dtype test=develop

* fix test=dvelop

* test fix ci compile test=develop

* fix build and merge test=develop

* fallback momentumop change to general test=develop
上级 45dd3491
...@@ -359,7 +359,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -359,7 +359,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
BuildStrategy::GradientScaleStrategy::kCustomized) { BuildStrategy::GradientScaleStrategy::kCustomized) {
// TODO(paddle-dev): Why is there no input for this op_handle? // TODO(paddle-dev): Why is there no input for this op_handle?
auto loss_grad_name = node->Op()->OutputArgumentNames()[0]; auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0]); auto out_dtype = all_vars_.at(loss_grad_name)->GetDataType();
CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0],
out_dtype);
} }
// This assumes the backward generating code will ensure IsScaleLossOp // This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scale the final scalar loss. // is true only for the op that scale the final scalar loss.
...@@ -662,13 +664,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID( ...@@ -662,13 +664,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
ir::Graph *result, const std::string &loss_grad_name, ir::Graph *result, const std::string &loss_grad_name,
ir::Node *out_var_node) const { ir::Node *out_var_node, proto::VarType::Type dtype) const {
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle // Insert ScaleCost OpHandle
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]); auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
auto *op_handle = new ScaleLossGradOpHandle( auto *op_handle = new ScaleLossGradOpHandle(
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx); local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx, dtype);
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle); result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale // FIXME: Currently ScaleLossGradOp only use device_count as scale
......
...@@ -68,7 +68,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -68,7 +68,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void CreateScaleLossGradOp(ir::Graph *result, void CreateScaleLossGradOp(ir::Graph *result,
const std::string &loss_grad_name, const std::string &loss_grad_name,
ir::Node *out_var_node) const; ir::Node *out_var_node,
proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og, VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const; int dst_dev_id) const;
......
...@@ -22,39 +22,66 @@ namespace details { ...@@ -22,39 +22,66 @@ namespace details {
ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev,
Scope *scope, Scope *scope,
platform::Place place, platform::Place place,
platform::DeviceContext *dev_ctx) platform::DeviceContext *dev_ctx,
proto::VarType::Type dtype)
: OpHandleBase(node), : OpHandleBase(node),
coeff_(static_cast<float>(1.0 / num_dev)), coeff_(static_cast<float>(1.0 / num_dev)),
scope_(scope), scope_(scope),
place_(place) { place_(place),
out_dtype_(dtype) {
this->SetDeviceContext(place_, dev_ctx); this->SetDeviceContext(place_, dev_ctx);
} }
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {} ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
void ScaleLossGradOpHandle::RunImpl() { struct ScaleLossGradFunctor {
// Doesn't wait any event float coeff_;
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_; Tensor *out_;
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); platform::Place place_;
OpHandleBase *op_handle_;
proto::VarType::Type out_dtype_;
platform::DeviceContext *ctx_;
float *tmp = local_scope.FindVar(var_name) ScaleLossGradFunctor(float coeff, Tensor *out, platform::Place place,
->GetMutable<LoDTensor>() OpHandleBase *op_handle, proto::VarType::Type dtype,
->mutable_data<float>(make_ddim({1}), place_); platform::DeviceContext *ctx)
: coeff_(coeff), out_(out), place_(place), out_dtype_(dtype), ctx_(ctx) {}
template <typename OutT>
void apply() const {
auto *out_data = out_->mutable_data<OutT>(place_);
if (platform::is_cpu_place(place_)) { if (platform::is_cpu_place(place_)) {
*tmp = coeff_; *out_data = static_cast<OutT>(coeff_);
} else { } else {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
this->RunAndRecordEvent([&] { OutT cast_coeff = static_cast<OutT>(coeff_);
auto stream = static_cast<platform::CUDADeviceContext *>( auto stream = static_cast<platform::CUDADeviceContext *>(ctx_)->stream();
this->dev_ctxes_.at(place_)) memory::Copy(boost::get<platform::CUDAPlace>(place_), out_data,
->stream(); platform::CPUPlace(), &cast_coeff, SizeOfType(out_dtype_),
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp, stream);
platform::CPUPlace(), &coeff_, sizeof(float), stream);
VLOG(10) << place_ << "RUN Scale loss grad op"; VLOG(10) << place_ << "RUN Scale loss grad op";
});
#endif #endif
} }
}
};
void ScaleLossGradOpHandle::RunImpl() {
// Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto *tensor = local_scope.FindVar(var_name)->GetMutable<LoDTensor>();
tensor->Resize(make_ddim({1}));
#ifdef PADDLE_WITH_CUDA
ScaleLossGradFunctor func(coeff_, tensor, place_, this, out_dtype_,
this->dev_ctxes_.at(place_));
this->RunAndRecordEvent([&] { framework::VisitDataType(out_dtype_, func); });
#else
ScaleLossGradFunctor func(coeff_, tensor, place_, this, out_dtype_, nullptr);
framework::VisitDataType(out_dtype_, func);
#endif
} }
std::string ScaleLossGradOpHandle::Name() const { return "Scale LossGrad"; } std::string ScaleLossGradOpHandle::Name() const { return "Scale LossGrad"; }
......
...@@ -26,8 +26,8 @@ namespace details { ...@@ -26,8 +26,8 @@ namespace details {
struct ScaleLossGradOpHandle : public OpHandleBase { struct ScaleLossGradOpHandle : public OpHandleBase {
ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope, ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope,
platform::Place place, platform::Place place, platform::DeviceContext *context,
platform::DeviceContext *context); proto::VarType::Type dtype);
~ScaleLossGradOpHandle() final; ~ScaleLossGradOpHandle() final;
...@@ -40,6 +40,7 @@ struct ScaleLossGradOpHandle : public OpHandleBase { ...@@ -40,6 +40,7 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
float coeff_; float coeff_;
Scope *scope_; Scope *scope_;
platform::Place place_; platform::Place place_;
proto::VarType::Type out_dtype_;
}; };
} // namespace details } // namespace details
......
...@@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h" #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_div, elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad, elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
......
...@@ -12,19 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,19 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_mul, elementwise_mul, ops::ElementwiseMulKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad, elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
int64_t>); ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>);
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fill_zeros_like_op.h" #include "paddle/fluid/operators/fill_zeros_like_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
...@@ -22,4 +23,6 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -22,4 +23,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, float>, ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>, ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>); ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <thrust/reduce.h> #include <thrust/reduce.h>
#include "paddle/fluid/operators/metrics/accuracy_op.h" #include "paddle/fluid/operators/metrics/accuracy_op.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
namespace paddle { namespace paddle {
...@@ -94,6 +95,7 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> { ...@@ -94,6 +95,7 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
// FIXME(typhoonzero): types of T is for inference data. // FIXME(typhoonzero): types of T is for inference data.
// label data is always int64 // label data is always int64
REGISTER_OP_CUDA_KERNEL(accuracy, REGISTER_OP_CUDA_KERNEL(
paddle::operators::AccuracyOpCUDAKernel<float>, accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
paddle::operators::AccuracyOpCUDAKernel<double>); paddle::operators::AccuracyOpCUDAKernel<double>,
paddle::operators::AccuracyOpCUDAKernel<paddle::platform::float16>);
...@@ -14,8 +14,11 @@ limitations under the License. */ ...@@ -14,8 +14,11 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/optimizers/momentum_op.h" #include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
momentum, ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, float>, momentum, ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>); ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::MomentumOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
...@@ -237,7 +237,8 @@ class SparseMomentumFunctor<T, UseNesterov> { ...@@ -237,7 +237,8 @@ class SparseMomentumFunctor<T, UseNesterov> {
inline HOSTDEVICE void operator()(size_t i) { inline HOSTDEVICE void operator()(size_t i) {
auto row_idx = auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_); math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0; T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_]
: static_cast<T>(0);
// put memory access in register // put memory access in register
const T p = p_[i]; const T p = p_[i];
const T lr = lr_[0]; const T lr = lr_[0];
...@@ -282,7 +283,8 @@ class SparseMomentumFunctor<T, NoNesterov> { ...@@ -282,7 +283,8 @@ class SparseMomentumFunctor<T, NoNesterov> {
inline HOSTDEVICE void operator()(size_t i) { inline HOSTDEVICE void operator()(size_t i) {
auto row_idx = auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_); math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0; T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_]
: static_cast<T>(0);
// put memory access in register // put memory access in register
const T p = p_[i]; const T p = p_[i];
const T lr = lr_[0]; const T lr = lr_[0];
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -150,7 +151,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam, ...@@ -150,7 +151,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
if (k < MaxLength - (*beam)) { if (k < MaxLength - (*beam)) {
topk[k] = topk[k + *beam]; topk[k] = topk[k + *beam];
} else { } else {
topk[k].set(-INFINITY, -1); topk[k].set(-static_cast<T>(INFINITY), -1);
} }
} }
if (!(*is_empty)) { if (!(*is_empty)) {
...@@ -160,7 +161,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam, ...@@ -160,7 +161,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
} }
*max = topk[MaxLength - 1]; *max = topk[MaxLength - 1];
if ((*max).v == -1) *is_empty = true; if ((*max).v == -static_cast<T>(1)) *is_empty = true;
*beam = 0; *beam = 0;
} }
} }
...@@ -181,7 +182,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam, ...@@ -181,7 +182,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
if (k < MaxLength - *beam) { if (k < MaxLength - *beam) {
topk[k] = topk[k + *beam]; topk[k] = topk[k + *beam];
} else { } else {
topk[k].set(-INFINITY, -1); topk[k].set(-static_cast<T>(INFINITY), -1);
} }
} }
if (!(*is_empty)) { if (!(*is_empty)) {
...@@ -278,7 +279,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, ...@@ -278,7 +279,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
bool firststep = true; bool firststep = true;
for (int j = 0; j < MaxLength; j++) { for (int j = 0; j < MaxLength; j++) {
topk[j].set(-INFINITY, -1); topk[j].set(-static_cast<T>(INFINITY), -1);
} }
while (top_num) { while (top_num) {
ThreadGetTopK<T, MaxLength, BlockSize>( ThreadGetTopK<T, MaxLength, BlockSize>(
...@@ -362,5 +363,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> { ...@@ -362,5 +363,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>, REGISTER_OP_CUDA_KERNEL(
paddle::operators::TopkOpCUDAKernel<double>); top_k, paddle::operators::TopkOpCUDAKernel<float>,
paddle::operators::TopkOpCUDAKernel<double>,
paddle::operators::TopkOpCUDAKernel<paddle::platform::float16>);
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#define NCCL_ID_VARNAME "NCCLID" #define NCCL_ID_VARNAME "NCCLID"
...@@ -38,6 +39,8 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) { ...@@ -38,6 +39,8 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
return ncclInt; return ncclInt;
} else if (type == framework::proto::VarType::INT64) { } else if (type == framework::proto::VarType::INT64) {
return ncclInt64; return ncclInt64;
} else if (type == framework::proto::VarType::FP16) {
return ncclFloat16;
} else { } else {
PADDLE_THROW("Not supported"); PADDLE_THROW("Not supported");
} }
......
...@@ -44,6 +44,8 @@ class DataToLoDTensorConverter(object): ...@@ -44,6 +44,8 @@ class DataToLoDTensorConverter(object):
self.dtype = 'int64' self.dtype = 'int64'
elif dtype == core.VarDesc.VarType.FP64: elif dtype == core.VarDesc.VarType.FP64:
self.dtype = 'float64' self.dtype = 'float64'
elif dtype == core.VarDesc.VarType.FP16:
self.dtype = 'float16'
elif dtype == core.VarDesc.VarType.INT32: elif dtype == core.VarDesc.VarType.INT32:
self.dtype = 'int32' self.dtype = 'int32'
elif dtype == core.VarDesc.VarType.UINT8: elif dtype == core.VarDesc.VarType.UINT8:
......
...@@ -18,6 +18,7 @@ from . import framework ...@@ -18,6 +18,7 @@ from . import framework
import numpy as np import numpy as np
import contextlib import contextlib
from .core import VarDesc from .core import VarDesc
from . import unique_name
__all__ = [ __all__ = [
'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear', 'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear',
...@@ -207,16 +208,39 @@ class UniformInitializer(Initializer): ...@@ -207,16 +208,39 @@ class UniformInitializer(Initializer):
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
if self._seed == 0: if self._seed == 0:
self._seed = block.program.random_seed self._seed = block.program.random_seed
# to be compatible of fp16 initalizers
if var.dtype == VarDesc.VarType.FP16:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(['gaussian_random', 'tmp'])),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False)
else:
out_dtype = var.dtype
out_var = var
op = block._prepend_op( op = block._prepend_op(
type="uniform_random", type="uniform_random",
outputs={"Out": var}, outputs={"Out": out_var},
attrs={ attrs={
"shape": var.shape, "shape": var.shape,
"dtype": int(var.dtype), "dtype": out_dtype,
"min": self._low, "min": self._low,
"max": self._high, "max": self._high,
"seed": self._seed "seed": self._seed
}) })
if var.dtype == VarDesc.VarType.FP16:
block.append_op(
type="cast",
inputs={"X": out_var},
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
var.op = op var.op = op
return op return op
...@@ -261,17 +285,39 @@ class NormalInitializer(Initializer): ...@@ -261,17 +285,39 @@ class NormalInitializer(Initializer):
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
if self._seed == 0: if self._seed == 0:
self._seed = block.program.random_seed self._seed = block.program.random_seed
# to be compatible of fp16 initalizers
if var.dtype == VarDesc.VarType.FP16:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(['gaussian_random', 'tmp'])),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False)
else:
out_dtype = var.dtype
out_var = var
op = block._prepend_op( op = block._prepend_op(
type="gaussian_random", type="gaussian_random",
outputs={"Out": var}, outputs={"Out": out_var},
attrs={ attrs={
"shape": var.shape, "shape": var.shape,
"dtype": int(var.dtype), "dtype": out_dtype,
"mean": self._mean, "mean": self._mean,
"std": self._std_dev, "std": self._std_dev,
"seed": self._seed, "seed": self._seed,
"use_mkldnn": False "use_mkldnn": False
}) })
if var.dtype == VarDesc.VarType.FP16:
block.append_op(
type="cast",
inputs={"X": out_var},
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
var.op = op var.op = op
return op return op
......
...@@ -63,6 +63,8 @@ def noam_decay(d_model, warmup_steps): ...@@ -63,6 +63,8 @@ def noam_decay(d_model, warmup_steps):
Returns: Returns:
The decayed learning rate. The decayed learning rate.
""" """
def _lr_schedule(dtype):
with default_main_program()._lr_schedule_guard(): with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter(1) global_step = _decay_step_counter(1)
...@@ -72,6 +74,8 @@ def noam_decay(d_model, warmup_steps): ...@@ -72,6 +74,8 @@ def noam_decay(d_model, warmup_steps):
return lr_value return lr_value
return _lr_schedule
def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False): def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
""" """
...@@ -109,6 +113,8 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False): ...@@ -109,6 +113,8 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
""" """
def _lr_schedule(dtype):
with default_main_program()._lr_schedule_guard(): with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter() global_step = _decay_step_counter()
...@@ -119,6 +125,8 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False): ...@@ -119,6 +125,8 @@ def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
return decayed_lr return decayed_lr
return _lr_schedule
def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False): def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
"""Applies natural exponential decay to the initial learning rate. """Applies natural exponential decay to the initial learning rate.
...@@ -138,6 +146,8 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False): ...@@ -138,6 +146,8 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
Returns: Returns:
The decayed learning rate The decayed learning rate
""" """
def _lr_schedule(dtype):
with default_main_program()._lr_schedule_guard(): with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter() global_step = _decay_step_counter()
...@@ -148,6 +158,8 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False): ...@@ -148,6 +158,8 @@ def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
return decayed_lr return decayed_lr
return _lr_schedule
def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False): def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
""" """
...@@ -184,6 +196,8 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False): ...@@ -184,6 +196,8 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
staircase=True)) staircase=True))
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
""" """
def _lr_schedule(dtype):
with default_main_program()._lr_schedule_guard(): with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter() global_step = _decay_step_counter()
...@@ -195,6 +209,8 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False): ...@@ -195,6 +209,8 @@ def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
return decayed_lr return decayed_lr
return _lr_schedule
def polynomial_decay(learning_rate, def polynomial_decay(learning_rate,
decay_steps, decay_steps,
...@@ -224,15 +240,17 @@ def polynomial_decay(learning_rate, ...@@ -224,15 +240,17 @@ def polynomial_decay(learning_rate,
Returns: Returns:
Variable: The decayed learning rate Variable: The decayed learning rate
""" """
def _lr_schedule(dtype, decay_steps=decay_steps):
with default_main_program()._lr_schedule_guard(): with default_main_program()._lr_schedule_guard():
global_step = _decay_step_counter() global_step = _decay_step_counter()
if cycle: if cycle:
div_res = ops.ceil(global_step / decay_steps) div_res = ops.ceil(global_step / decay_steps)
zero_var = tensor.fill_constant( zero_var = tensor.fill_constant(
shape=[1], dtype='float32', value=0.0) shape=[1], dtype=dtype, value=0.0)
one_var = tensor.fill_constant( one_var = tensor.fill_constant(
shape=[1], dtype='float32', value=1.0) shape=[1], dtype=dtype, value=1.0)
with control_flow.Switch() as switch: with control_flow.Switch() as switch:
with switch.case(global_step == zero_var): with switch.case(global_step == zero_var):
...@@ -240,13 +258,16 @@ def polynomial_decay(learning_rate, ...@@ -240,13 +258,16 @@ def polynomial_decay(learning_rate,
decay_steps = decay_steps * div_res decay_steps = decay_steps * div_res
else: else:
decay_steps_var = tensor.fill_constant( decay_steps_var = tensor.fill_constant(
shape=[1], dtype='float32', value=float(decay_steps)) shape=[1], dtype=dtype, value=float(decay_steps))
global_step = nn.elementwise_min(x=global_step, y=decay_steps_var) global_step = nn.elementwise_min(
x=global_step, y=decay_steps_var)
decayed_lr = (learning_rate - end_learning_rate) * \ decayed_lr = (learning_rate - end_learning_rate) * \
((1 - global_step / decay_steps) ** power) + end_learning_rate ((1 - global_step / decay_steps) ** power) + end_learning_rate
return decayed_lr return decayed_lr
return _lr_schedule
def piecewise_decay(boundaries, values): def piecewise_decay(boundaries, values):
"""Applies piecewise decay to the initial learning rate. """Applies piecewise decay to the initial learning rate.
...@@ -273,6 +294,8 @@ def piecewise_decay(boundaries, values): ...@@ -273,6 +294,8 @@ def piecewise_decay(boundaries, values):
""" """
def _lr_schedule(dtype):
with default_main_program()._lr_schedule_guard(): with default_main_program()._lr_schedule_guard():
if len(values) - len(boundaries) != 1: if len(values) - len(boundaries) != 1:
raise ValueError("len(values) - len(boundaries) should be 1") raise ValueError("len(values) - len(boundaries) should be 1")
...@@ -306,6 +329,8 @@ def piecewise_decay(boundaries, values): ...@@ -306,6 +329,8 @@ def piecewise_decay(boundaries, values):
return lr return lr
return _lr_schedule
def append_LARS(params_grads, learning_rate, weight_decay): def append_LARS(params_grads, learning_rate, weight_decay):
""" """
......
...@@ -2798,6 +2798,10 @@ def batch_norm(input, ...@@ -2798,6 +2798,10 @@ def batch_norm(input,
helper = LayerHelper('batch_norm', **locals()) helper = LayerHelper('batch_norm', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
# use fp32 for bn parameter
if dtype == core.VarDesc.VarType.FP16:
dtype = core.VarDesc.VarType.FP32
input_shape = input.shape input_shape = input.shape
if data_layout == 'NCHW': if data_layout == 'NCHW':
channel_num = input_shape[1] channel_num = input_shape[1]
...@@ -2832,7 +2836,7 @@ def batch_norm(input, ...@@ -2832,7 +2836,7 @@ def batch_norm(input,
trainable=False, trainable=False,
do_model_average=do_model_average_for_mean_and_var), do_model_average=do_model_average_for_mean_and_var),
shape=param_shape, shape=param_shape,
dtype=input.dtype) dtype=dtype)
mean.stop_gradient = True mean.stop_gradient = True
variance = helper.create_parameter( variance = helper.create_parameter(
...@@ -2842,7 +2846,7 @@ def batch_norm(input, ...@@ -2842,7 +2846,7 @@ def batch_norm(input,
trainable=False, trainable=False,
do_model_average=do_model_average_for_mean_and_var), do_model_average=do_model_average_for_mean_and_var),
shape=param_shape, shape=param_shape,
dtype=input.dtype) dtype=dtype)
variance.stop_gradient = True variance.stop_gradient = True
# create output # create output
......
...@@ -50,17 +50,21 @@ class Optimizer(object): ...@@ -50,17 +50,21 @@ class Optimizer(object):
def __init__(self, learning_rate, regularization=None, name=None): def __init__(self, learning_rate, regularization=None, name=None):
if not isinstance(learning_rate, float) and \ if not isinstance(learning_rate, float) and \
not isinstance(learning_rate, framework.Variable): not isinstance(learning_rate, framework.Variable) and \
raise TypeError("learning rate should be float or Variable") not callable(learning_rate):
raise TypeError(
"learning rate should be float or Variable or callable(dtype)")
self._name = name self._name = name
self.regularization = regularization self.regularization = regularization
self._learning_rate = learning_rate self._learning_rate = learning_rate
# the learning rate type should be inferenced from loss # the learning rate type should be inferenced from loss
self._dtype = None self._dtype = None
# each program should have a independent learning rate # each program should have a independent learning rate
# program -> Variable(learning_rate) # program -> Variable(learning_rate) or:
# program -> callable(return learning_rate Variable)
self._learning_rate_map = dict() self._learning_rate_map = dict()
if isinstance(self._learning_rate, framework.Variable): if isinstance(self._learning_rate, framework.Variable) or \
callable(self._learning_rate):
self._learning_rate_map[framework.default_main_program( self._learning_rate_map[framework.default_main_program(
)] = self._learning_rate )] = self._learning_rate
# Dictionary of accumulators. Some optimizer subclasses need to # Dictionary of accumulators. Some optimizer subclasses need to
...@@ -75,6 +79,11 @@ class Optimizer(object): ...@@ -75,6 +79,11 @@ class Optimizer(object):
if isinstance(lr, framework.Variable): if isinstance(lr, framework.Variable):
return return
elif callable(lr):
dtype = 'float32' if self._dtype is None else self._dtype
self._learning_rate_map[framework.default_main_program()] = lr(
dtype)
return
else: else:
if not isinstance(self._learning_rate, float): if not isinstance(self._learning_rate, float):
raise TypeError( raise TypeError(
......
...@@ -370,6 +370,8 @@ class OpTest(unittest.TestCase): ...@@ -370,6 +370,8 @@ class OpTest(unittest.TestCase):
return [place] return [place]
else: else:
return [] return []
else:
return []
places = [fluid.CPUPlace()] places = [fluid.CPUPlace()]
cpu_only = self._cpu_only if hasattr(self, '_cpu_only') else False cpu_only = self._cpu_only if hasattr(self, '_cpu_only') else False
if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type)\ if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type)\
......
...@@ -22,8 +22,10 @@ from op_test import OpTest ...@@ -22,8 +22,10 @@ from op_test import OpTest
class TestAccuracyOp(OpTest): class TestAccuracyOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "accuracy" self.op_type = "accuracy"
self.dtype = np.float32
self.init_dtype()
n = 8192 n = 8192
infer = np.random.random((n, 1)).astype("float32") infer = np.random.random((n, 1)).astype(self.dtype)
indices = np.random.randint(0, 2, (n, 1)) indices = np.random.randint(0, 2, (n, 1))
label = np.random.randint(0, 2, (n, 1)) label = np.random.randint(0, 2, (n, 1))
self.inputs = {'Out': infer, 'Indices': indices, "Label": label} self.inputs = {'Out': infer, 'Indices': indices, "Label": label}
...@@ -34,14 +36,25 @@ class TestAccuracyOp(OpTest): ...@@ -34,14 +36,25 @@ class TestAccuracyOp(OpTest):
num_correct += 1 num_correct += 1
break break
self.outputs = { self.outputs = {
'Accuracy': np.array([num_correct / float(n)]).astype("float32"), 'Accuracy': np.array([num_correct / float(n)]).astype(self.dtype),
'Correct': np.array([num_correct]).astype("int32"), 'Correct': np.array([num_correct]).astype("int32"),
'Total': np.array([n]).astype("int32") 'Total': np.array([n]).astype("int32")
} }
def init_dtype(self):
pass
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
class TestAccuracyOpFp16(TestAccuracyOp):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output(atol=1e-3)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -21,14 +21,16 @@ from op_test import OpTest ...@@ -21,14 +21,16 @@ from op_test import OpTest
class ElementwiseDivOp(OpTest): class ElementwiseDivOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "elementwise_div" self.op_type = "elementwise_div"
self.dtype = np.float32
self.init_dtype()
""" Warning """ Warning
CPU gradient check error! CPU gradient check error!
'X': np.random.random((32,84)).astype("float32"), 'X': np.random.random((32,84)).astype("float32"),
'Y': np.random.random((32,84)).astype("float32") 'Y': np.random.random((32,84)).astype("float32")
""" """
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"), 'X': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype),
'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32") 'Y': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
} }
self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
...@@ -46,6 +48,9 @@ class ElementwiseDivOp(OpTest): ...@@ -46,6 +48,9 @@ class ElementwiseDivOp(OpTest):
self.check_grad( self.check_grad(
['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y')) ['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y'))
def init_dtype(self):
pass
class TestElementwiseDivOp_scalar(ElementwiseDivOp): class TestElementwiseDivOp_scalar(ElementwiseDivOp):
def setUp(self): def setUp(self):
...@@ -126,5 +131,21 @@ class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp): ...@@ -126,5 +131,21 @@ class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp):
} }
class TestElementwiseDivOpFp16(ElementwiseDivOp):
def init_dtype(self):
self.dtype = np.float16
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=1)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=1, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=1, no_grad_set=set('Y'))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -135,5 +135,10 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp): ...@@ -135,5 +135,10 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
} }
class TestElementwiseMulOpFp16(ElementwiseMulOp):
def init_dtype(self):
self.dtype = np.float16
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -22,12 +22,22 @@ from op_test import OpTest ...@@ -22,12 +22,22 @@ from op_test import OpTest
class TestFillZerosLikeOp(OpTest): class TestFillZerosLikeOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "fill_zeros_like" self.op_type = "fill_zeros_like"
self.inputs = {'X': np.random.random((219, 232)).astype("float32")} self.dtype = np.float32
self.init_dtype()
self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
self.outputs = {'Out': np.zeros_like(self.inputs["X"])} self.outputs = {'Out': np.zeros_like(self.inputs["X"])}
def init_dtype(self):
pass
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
class TestFillZerosLikeOpFp16(TestFillZerosLikeOp):
def init_dtype(self):
self.dtype = np.float16
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -97,7 +97,7 @@ class TestLearningRateDecay(unittest.TestCase): ...@@ -97,7 +97,7 @@ class TestLearningRateDecay(unittest.TestCase):
startup_prog = fluid.Program() startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
decayed_lr = fluid_decay_fn(**kwargs) decayed_lr = fluid_decay_fn(**kwargs)("float32")
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
......
...@@ -24,11 +24,13 @@ from op_test import OpTest ...@@ -24,11 +24,13 @@ from op_test import OpTest
class TestMomentumOp1(OpTest): class TestMomentumOp1(OpTest):
def setUp(self): def setUp(self):
self.op_type = "momentum" self.op_type = "momentum"
self.dtype = np.float32
self.init_dtype()
param = np.random.random((123, 321)).astype("float32") param = np.random.random((123, 321)).astype(self.dtype)
grad = np.random.random((123, 321)).astype("float32") grad = np.random.random((123, 321)).astype(self.dtype)
velocity = np.zeros((123, 321)).astype("float32") velocity = np.zeros((123, 321)).astype(self.dtype)
learning_rate = np.array([0.001]).astype("float32") learning_rate = np.array([0.001]).astype(self.dtype)
mu = 0.0001 mu = 0.0001
use_nesterov = False use_nesterov = False
...@@ -50,10 +52,21 @@ class TestMomentumOp1(OpTest): ...@@ -50,10 +52,21 @@ class TestMomentumOp1(OpTest):
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out} self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
def init_dtype(self):
pass
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
class TestMomentumOpFp16(TestMomentumOp1):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output(atol=1e-3)
class TestMomentumOp2(OpTest): class TestMomentumOp2(OpTest):
'''Test Momentum with default values for attributes '''Test Momentum with default values for attributes
''' '''
......
...@@ -23,8 +23,11 @@ class TestTopkOp(OpTest): ...@@ -23,8 +23,11 @@ class TestTopkOp(OpTest):
def setUp(self): def setUp(self):
self.set_args() self.set_args()
self.op_type = "top_k" self.op_type = "top_k"
self.dtype = np.float32
self.init_dtype()
k = self.top_k k = self.top_k
input = np.random.random((self.row, k)).astype("float32") input = np.random.random((self.row, k)).astype(self.dtype)
output = np.ndarray((self.row, k)) output = np.ndarray((self.row, k))
indices = np.ndarray((self.row, k)).astype("int64") indices = np.ndarray((self.row, k)).astype("int64")
...@@ -38,6 +41,9 @@ class TestTopkOp(OpTest): ...@@ -38,6 +41,9 @@ class TestTopkOp(OpTest):
self.outputs = {'Out': output, 'Indices': indices} self.outputs = {'Out': output, 'Indices': indices}
def init_dtype(self):
pass
def set_args(self): def set_args(self):
self.row = 32 self.row = 32
self.top_k = 1 self.top_k = 1
...@@ -46,6 +52,11 @@ class TestTopkOp(OpTest): ...@@ -46,6 +52,11 @@ class TestTopkOp(OpTest):
self.check_output() self.check_output()
class TestTopkOpFp16(TestTopkOp):
def init_dtype(self):
self.dtype = np.float16
class TestTopkOp3d(OpTest): class TestTopkOp3d(OpTest):
def setUp(self): def setUp(self):
self.op_type = "top_k" self.op_type = "top_k"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册