提交 5631fc08 编写于 作者: S shippingwang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into shufflechannel

......@@ -48,10 +48,10 @@ if(WITH_GPU)
nv_library(tensor SRCS tensor.cc .tensor_util.cu DEPS place memory data_type device_context)
add_dependencies(tensor tensor_util)
else()
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type device_context)
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type device_context )
endif(WIN32)
else()
cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type device_context)
cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type device_context )
endif()
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
......
......@@ -355,7 +355,9 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
BuildStrategy::GradientScaleStrategy::kCustomized) {
// TODO(paddle-dev): Why is there no input for this op_handle?
auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0]);
auto out_dtype = all_vars_.at(loss_grad_name)->GetDataType();
CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0],
out_dtype);
}
// This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scale the final scalar loss.
......@@ -658,13 +660,13 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
ir::Graph *result, const std::string &loss_grad_name,
ir::Node *out_var_node) const {
ir::Node *out_var_node, proto::VarType::Type dtype) const {
for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
auto *op_handle = new ScaleLossGradOpHandle(
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx);
local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx, dtype);
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
......
......@@ -68,7 +68,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void CreateScaleLossGradOp(ir::Graph *result,
const std::string &loss_grad_name,
ir::Node *out_var_node) const;
ir::Node *out_var_node,
proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const;
......
......@@ -22,39 +22,66 @@ namespace details {
ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev,
Scope *scope,
platform::Place place,
platform::DeviceContext *dev_ctx)
platform::DeviceContext *dev_ctx,
proto::VarType::Type dtype)
: OpHandleBase(node),
coeff_(static_cast<float>(1.0 / num_dev)),
scope_(scope),
place_(place) {
place_(place),
out_dtype_(dtype) {
this->SetDeviceContext(place_, dev_ctx);
}
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
struct ScaleLossGradFunctor {
float coeff_;
Tensor *out_;
platform::Place place_;
OpHandleBase *op_handle_;
proto::VarType::Type out_dtype_;
platform::DeviceContext *ctx_;
ScaleLossGradFunctor(float coeff, Tensor *out, platform::Place place,
OpHandleBase *op_handle, proto::VarType::Type dtype,
platform::DeviceContext *ctx)
: coeff_(coeff), out_(out), place_(place), out_dtype_(dtype), ctx_(ctx) {}
template <typename OutT>
void apply() const {
auto *out_data = out_->mutable_data<OutT>(place_);
if (platform::is_cpu_place(place_)) {
*out_data = static_cast<OutT>(coeff_);
} else {
#ifdef PADDLE_WITH_CUDA
OutT cast_coeff = static_cast<OutT>(coeff_);
auto stream = static_cast<platform::CUDADeviceContext *>(ctx_)->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), out_data,
platform::CPUPlace(), &cast_coeff, SizeOfType(out_dtype_),
stream);
VLOG(10) << place_ << "RUN Scale loss grad op";
#endif
}
}
};
void ScaleLossGradOpHandle::RunImpl() {
// Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
float *tmp = local_scope.FindVar(var_name)
->GetMutable<LoDTensor>()
->mutable_data<float>(make_ddim({1}), place_);
auto *tensor = local_scope.FindVar(var_name)->GetMutable<LoDTensor>();
tensor->Resize(make_ddim({1}));
if (platform::is_cpu_place(place_)) {
*tmp = coeff_;
} else {
#ifdef PADDLE_WITH_CUDA
this->RunAndRecordEvent([&] {
auto stream = static_cast<platform::CUDADeviceContext *>(
this->dev_ctxes_.at(place_))
->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream);
VLOG(10) << place_ << "RUN Scale loss grad op";
});
ScaleLossGradFunctor func(coeff_, tensor, place_, this, out_dtype_,
this->dev_ctxes_.at(place_));
this->RunAndRecordEvent([&] { framework::VisitDataType(out_dtype_, func); });
#else
ScaleLossGradFunctor func(coeff_, tensor, place_, this, out_dtype_, nullptr);
framework::VisitDataType(out_dtype_, func);
#endif
}
}
std::string ScaleLossGradOpHandle::Name() const { return "Scale LossGrad"; }
......
......@@ -26,8 +26,8 @@ namespace details {
struct ScaleLossGradOpHandle : public OpHandleBase {
ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope,
platform::Place place,
platform::DeviceContext *context);
platform::Place place, platform::DeviceContext *context,
proto::VarType::Type dtype);
~ScaleLossGradOpHandle() final;
......@@ -40,6 +40,7 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
float coeff_;
Scope *scope_;
platform::Place place_;
proto::VarType::Type out_dtype_;
};
} // namespace details
......
......@@ -82,6 +82,10 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
AddAttr<std::string>(OpNamescopeAttrName(), "Operator name with namesope.")
.SetDefault("");
AddAttr<std::vector<std::string>>(OpCreationCallstackAttrName(),
"Callstack for Op Creatation.")
.SetDefault({});
Validate();
}
......
......@@ -47,6 +47,7 @@ class OpProtoAndCheckerMaker {
static const char *OpRoleAttrName() { return "op_role"; }
static const char *OpRoleVarAttrName() { return "op_role_var"; }
static const char *OpNamescopeAttrName() { return "op_namescope"; }
static const char *OpCreationCallstackAttrName() { return "op_callstack"; }
void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);
......
......@@ -16,10 +16,15 @@ limitations under the License. */
#include <glog/logging.h>
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
......@@ -157,27 +162,59 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
}
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
VLOG(4) << place << " " << DebugStringEx(&scope);
if (platform::is_gpu_place(place)) {
try {
if (VLOG_IS_ON(4)) {
VLOG(4) << place << " " << DebugStringEx(&scope);
}
if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot run operator on place %s", place);
PADDLE_THROW("Cannot run operator on place %s", place);
#else
auto dev_id = boost::get<platform::CUDAPlace>(place).device;
platform::SetDeviceId(dev_id);
auto dev_id = boost::get<platform::CUDAPlace>(place).device;
platform::SetDeviceId(dev_id);
#endif
}
}
// The profile has a process-wide mutex, results in serious performance issue
// in concurrency scenerio. Here use an `if` to fix this issue.
// Please not remove the `if`, ask @Superjomn if there are any concern.
if (platform::IsProfileEnabled()) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
RunImpl(scope, place);
} else {
RunImpl(scope, place);
// The profile has a process-wide mutex, results in serious performance
// issue
// in concurrency scenerio. Here use an `if` to fix this issue.
// Please not remove the `if`, ask @Superjomn if there are any concern.
if (platform::IsProfileEnabled()) {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
RunImpl(scope, place);
} else {
RunImpl(scope, place);
}
if (VLOG_IS_ON(3)) {
VLOG(3) << place << " " << DebugStringEx(&scope);
}
} catch (platform::EnforceNotMet exception) {
if (Attrs().count("sub_block") != 0) {
throw exception;
}
auto& callstack = Attr<std::vector<std::string>>(
OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
if (callstack.empty()) {
throw exception;
}
std::ostringstream sout;
sout << "Invoke operator " << Type() << " error.\n";
sout << "Python Callstacks: \n";
for (auto& line : callstack) {
sout << line;
}
sout << "C++ Callstacks: \n";
sout << exception.err_str_;
exception.err_str_ = sout.str();
throw exception;
} catch (...) {
std::rethrow_exception(std::current_exception());
}
VLOG(3) << place << " " << DebugStringEx(&scope);
}
bool OperatorBase::HasInputs(const std::string& name) const {
......
......@@ -28,8 +28,7 @@ void Tensor::check_memory_size() const {
"or maybe the required data-type mismatches the data already stored.");
}
Tensor::Tensor(std::type_index type)
: type_(framework::ToDataType(type)), offset_(0) {}
Tensor::Tensor(const proto::VarType::Type& dtype) : type_(dtype), offset_(0) {}
size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_;
......
......@@ -69,7 +69,7 @@ class Tensor {
public:
Tensor() : type_(proto::VarType::FP32), offset_(0) {}
explicit Tensor(std::type_index type);
explicit Tensor(const proto::VarType::Type&);
/*! Return a pointer to mutable memory block. */
template <typename T>
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/temporary_allocator.h"
namespace paddle {
namespace framework {
......@@ -151,5 +152,26 @@ void TensorToVector(const Tensor& src, std::vector<T>* dst) {
src_ptr, size);
}
template <typename T>
paddle::framework::Tensor GetTensor(
memory::allocation::AllocationPtr temp_allocation_ptr,
const framework::DDim& dim) {
auto& deleter = temp_allocation_ptr.get_deleter();
auto* allocation_ptr = temp_allocation_ptr.release();
auto shared_allocation =
std::shared_ptr<memory::allocation::Allocation>(allocation_ptr, deleter);
PADDLE_ENFORCE(
dynamic_cast<platform::TemporaryAllocation*>(allocation_ptr) != nullptr,
"The AllocationPtr must be TemporaryAllocation.");
PADDLE_ENFORCE_EQ(allocation_ptr->size(),
framework::product(dim) * sizeof(T));
paddle::framework::Tensor temp_tensor(
framework::ToDataType(std::type_index(typeid(T))));
temp_tensor.Resize(dim);
temp_tensor.ResetHolder(std::move(shared_allocation));
return temp_tensor;
}
} // namespace framework
} // namespace paddle
......@@ -75,6 +75,11 @@ set(LAC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/lac")
download_model_and_data(${LAC_INSTALL_DIR} "lac_model.tar.gz" "lac_data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_lac ${LAC_INSTALL_DIR} analyzer_lac_tester.cc)
# MM DNN
set(MM_DNN_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/mm_dnn")
download_model_and_data(${MM_DNN_INSTALL_DIR} "MM_DNN_model.tar.gz" "MM_DNN_data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_mm_dnn ${MM_DNN_INSTALL_DIR} analyzer_mm_dnn_tester.cc)
# text_classification
set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classification")
download_model_and_data(${TEXT_CLASSIFICATION_INSTALL_DIR} "text-classification-Senta.tar.gz" "text_classification_data.txt.tar.gz")
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
namespace inference {
using contrib::AnalysisConfig;
struct DataRecord {
std::vector<std::vector<int64_t>> query_data_all, title_data_all;
std::vector<size_t> lod1, lod2;
size_t batch_iter{0};
size_t batch_size{1};
size_t num_samples; // total number of samples
DataRecord() = default;
explicit DataRecord(const std::string &path, int batch_size = 1)
: batch_size(batch_size) {
Load(path);
}
DataRecord NextBatch() {
DataRecord data;
size_t batch_end = batch_iter + batch_size;
// NOTE skip the final batch, if no enough data is provided.
if (batch_end <= query_data_all.size()) {
data.query_data_all.assign(query_data_all.begin() + batch_iter,
query_data_all.begin() + batch_end);
data.title_data_all.assign(title_data_all.begin() + batch_iter,
title_data_all.begin() + batch_end);
// Prepare LoDs
data.lod1.push_back(0);
data.lod2.push_back(0);
CHECK(!data.query_data_all.empty());
CHECK(!data.title_data_all.empty());
CHECK_EQ(data.query_data_all.size(), data.title_data_all.size());
for (size_t j = 0; j < data.query_data_all.size(); j++) {
// calculate lod
data.lod1.push_back(data.lod1.back() + data.query_data_all[j].size());
data.lod2.push_back(data.lod2.back() + data.title_data_all[j].size());
}
}
batch_iter += batch_size;
return data;
}
void Load(const std::string &path) {
std::ifstream file(path);
std::string line;
int num_lines = 0;
while (std::getline(file, line)) {
num_lines++;
std::vector<std::string> data;
split(line, '\t', &data);
// load query data
std::vector<int64_t> query_data;
split_to_int64(data[0], ' ', &query_data);
// load title data
std::vector<int64_t> title_data;
split_to_int64(data[1], ' ', &title_data);
query_data_all.push_back(std::move(query_data));
title_data_all.push_back(std::move(title_data));
}
num_samples = num_lines;
}
};
void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
int batch_size) {
PaddleTensor lod_query_tensor, lod_title_tensor;
lod_query_tensor.name = "left";
lod_title_tensor.name = "right";
auto one_batch = data->NextBatch();
int size1 = one_batch.lod1[one_batch.lod1.size() - 1]; // token batch size
int size2 = one_batch.lod2[one_batch.lod2.size() - 1]; // token batch size
lod_query_tensor.shape.assign({size1, 1});
lod_query_tensor.lod.assign({one_batch.lod1});
lod_title_tensor.shape.assign({size2, 1});
lod_title_tensor.lod.assign({one_batch.lod2});
// assign data
TensorAssignData<int64_t>(&lod_query_tensor, one_batch.query_data_all);
TensorAssignData<int64_t>(&lod_title_tensor, one_batch.title_data_all);
// Set inputs.
input_slots->assign({lod_query_tensor, lod_title_tensor});
for (auto &tensor : *input_slots) {
tensor.dtype = PaddleDType::INT64;
}
}
void SetConfig(contrib::AnalysisConfig *cfg) {
cfg->model_dir = FLAGS_infer_model;
cfg->use_gpu = false;
cfg->device = 0;
cfg->specify_input_name = true;
cfg->enable_ir_optim = true;
}
void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
std::vector<PaddleTensor> input_slots;
int epoch = FLAGS_test_all_data ? data.num_samples / FLAGS_batch_size : 1;
LOG(INFO) << "number of samples: " << epoch * FLAGS_batch_size;
for (int bid = 0; bid < epoch; ++bid) {
PrepareInputs(&input_slots, &data, FLAGS_batch_size);
(*inputs).emplace_back(input_slots);
}
}
// Easy for profiling independently.
TEST(Analyzer_MM_DNN, profile) {
contrib::AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all, &outputs, FLAGS_num_threads);
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
PADDLE_ENFORCE_EQ(outputs.size(), 2UL);
for (auto &output : outputs) {
size_t size = GetSize(output);
PADDLE_ENFORCE_GT(size, 0);
float *result = static_cast<float *>(output.data.data());
// output is probability, which is in (-1, 1).
for (size_t i = 0; i < size; i++) {
EXPECT_GT(result[i], -1);
EXPECT_LT(result[i], 1);
}
}
}
}
// Check the fuse status
TEST(Analyzer_MM_DNN, fuse_statis) {
contrib::AnalysisConfig cfg;
SetConfig(&cfg);
int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis(
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
}
// Compare result of NativeConfig and AnalysisConfig
TEST(Analyzer_MM_DNN, compare) {
contrib::AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
// Compare Deterministic result
TEST(Analyzer_MM_DNN, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference
} // namespace paddle
......@@ -18,11 +18,11 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/fluid/platform/create_tensor_with_allocationptr.h"
namespace paddle {
namespace operators {
......@@ -161,10 +161,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
auto tmp_allocation_ptr =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
framework::product(col_shape) * sizeof(T));
Tensor tep_tensor =
platform::GetTensor<T>(std::move(tmp_allocation_ptr), col_shape);
col.ShareDataWith(tep_tensor);
col = framework::GetTensor<T>(std::move(tmp_allocation_ptr), col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
......@@ -299,10 +296,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
auto tmp_allocation_ptr =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
framework::product(col_shape) * sizeof(T));
Tensor tep_tensor =
platform::GetTensor<T>(std::move(tmp_allocation_ptr), col_shape);
col.ShareDataWith(tep_tensor);
col = framework::GetTensor<T>(std::move(tmp_allocation_ptr), col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
......
......@@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
elementwise_div,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
......
......@@ -12,19 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::CUDADeviceContext, int64_t>);
elementwise_mul, ops::ElementwiseMulKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>);
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fill_zeros_like_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
......@@ -22,4 +23,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, float>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, double>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::FillZerosLikeKernel<paddle::platform::CUDADeviceContext, bool>);
......@@ -131,8 +131,9 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
int in_col = input[0].numel() / in_row;
int out_row = in_row, out_col = 0;
std::vector<T*> inputs_data(in_num);
std::vector<const T*> inputs_data;
std::vector<int> inputs_col(in_num + 1);
inputs_data.reserve(in_num);
inputs_col[0] = 0;
bool sameShape = true;
......@@ -143,7 +144,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
}
out_col += t_cols;
inputs_col[i + 1] = out_col;
inputs_data[i] = const_cast<T*>(input[i].data<T>());
inputs_data.emplace_back(input[i].data<T>());
}
// computation
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <thrust/reduce.h>
#include "paddle/fluid/operators/metrics/accuracy_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
......@@ -94,6 +95,7 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
// FIXME(typhoonzero): types of T is for inference data.
// label data is always int64
REGISTER_OP_CUDA_KERNEL(accuracy,
paddle::operators::AccuracyOpCUDAKernel<float>,
paddle::operators::AccuracyOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
accuracy, paddle::operators::AccuracyOpCUDAKernel<float>,
paddle::operators::AccuracyOpCUDAKernel<double>,
paddle::operators::AccuracyOpCUDAKernel<paddle::platform::float16>);
......@@ -14,8 +14,11 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
momentum, ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>);
ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::MomentumOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
......@@ -237,7 +237,8 @@ class SparseMomentumFunctor<T, UseNesterov> {
inline HOSTDEVICE void operator()(size_t i) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_]
: static_cast<T>(0);
// put memory access in register
const T p = p_[i];
const T lr = lr_[0];
......@@ -282,7 +283,8 @@ class SparseMomentumFunctor<T, NoNesterov> {
inline HOSTDEVICE void operator()(size_t i) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_]
: static_cast<T>(0);
// put memory access in register
const T p = p_[i];
const T lr = lr_[0];
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
......@@ -150,7 +151,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
if (k < MaxLength - (*beam)) {
topk[k] = topk[k + *beam];
} else {
topk[k].set(-INFINITY, -1);
topk[k].set(-static_cast<T>(INFINITY), -1);
}
}
if (!(*is_empty)) {
......@@ -160,7 +161,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
}
*max = topk[MaxLength - 1];
if ((*max).v == -1) *is_empty = true;
if ((*max).v == -static_cast<T>(1)) *is_empty = true;
*beam = 0;
}
}
......@@ -181,7 +182,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int* beam,
if (k < MaxLength - *beam) {
topk[k] = topk[k + *beam];
} else {
topk[k].set(-INFINITY, -1);
topk[k].set(-static_cast<T>(INFINITY), -1);
}
}
if (!(*is_empty)) {
......@@ -278,7 +279,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
bool firststep = true;
for (int j = 0; j < MaxLength; j++) {
topk[j].set(-INFINITY, -1);
topk[j].set(-static_cast<T>(INFINITY), -1);
}
while (top_num) {
ThreadGetTopK<T, MaxLength, BlockSize>(
......@@ -362,5 +363,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>,
paddle::operators::TopkOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
top_k, paddle::operators::TopkOpCUDAKernel<float>,
paddle::operators::TopkOpCUDAKernel<double>,
paddle::operators::TopkOpCUDAKernel<paddle::platform::float16>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/temporary_allocator.h"
namespace paddle {
namespace platform {
template <typename T>
paddle::framework::Tensor GetTensor(
memory::allocation::AllocationPtr temp_allocation_ptr,
const framework::DDim &dim) {
auto &deleter = temp_allocation_ptr.get_deleter();
auto *allocation_ptr = temp_allocation_ptr.release();
auto shared_allocation =
std::shared_ptr<memory::allocation::Allocation>(allocation_ptr, deleter);
PADDLE_ENFORCE(dynamic_cast<TemporaryAllocation *>(allocation_ptr) != nullptr,
"The AllocationPtr must be TemporaryAllocation.");
PADDLE_ENFORCE_EQ(allocation_ptr->size(),
framework::product(dim) * sizeof(T));
paddle::framework::Tensor temp_tensor(std::type_index(typeid(T)));
temp_tensor.Resize(dim);
temp_tensor.ResetHolder(std::move(shared_allocation));
return temp_tensor;
}
} // namespace platform
} // namespace paddle
......@@ -256,10 +256,11 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << place_.device
<< ", CUDA Capability: " << compute_capability_
<< ", Driver Version: " << driver_version_ / 1000
<< ", Driver API Version: " << driver_version_ / 1000
<< "." << (driver_version_ % 100) / 10
<< ", Runtime Version: " << runtime_version_ / 1000
<< "." << (runtime_version_ % 100) / 10;
<< ", Runtime API Version: "
<< runtime_version_ / 1000 << "."
<< (runtime_version_ % 100) / 10;
size_t cudnn_dso_ver = dynload::cudnnGetVersion();
LOG_FIRST_N(WARNING, 1) << "device: " << place_.device
<< ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
......
......@@ -41,7 +41,28 @@ limitations under the License. */
namespace paddle {
namespace platform {
/*! \brief device temporary allocator singleton */
/*! \brief device temporary allocator singleton.
*
* Some operator needs temporary memory during computation, for example,
* conv_gemm, which needs use col to store the result of im2col. If we
* create a stack memory which is used by CUDA Kernel, before the
* Computation(...) returns, we should add ctx->Wait(), because the
* execution of CUDA is async, if there doesn't have ctx->Wait(),
* the temporary memory will be released before the CUDA Kernel uses
* it.
*
* DeviceTemporaryAllocator is a singleton, which contains a
* `TemporaryAllocator` for each <Place, Stream>. And the TemporaryAllocator
* contains a temp_allocation_queue which is used to store the temporary
* allocations. The allocation, which is allocated by TemporaryAllocator,
* is a unique_ptr, and when it is not held by any variable, it will be
* pushed into the temp_allocation_queue. There are two opportunities to free
* the allocations of temp_allocation_queue:
* - when the Stream calls cudaStreamSynchronize;
* - when the allocation size of opportunities exceeds a certain threshold
* (defined by FLAGS_limit_of_temporary_allocation).
*
* */
class DeviceTemporaryAllocator {
public:
static DeviceTemporaryAllocator& Instance() {
......
......@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#define NCCL_ID_VARNAME "NCCLID"
......@@ -38,6 +39,8 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
return ncclInt;
} else if (type == framework::proto::VarType::INT64) {
return ncclInt64;
} else if (type == framework::proto::VarType::FP16) {
return ncclFloat16;
} else {
PADDLE_THROW("Not supported");
}
......
......@@ -29,6 +29,19 @@ class TemporaryAllocation : public memory::allocation::Allocation {
memory::allocation::AllocationPtr underlying_allocation_;
};
/*! \brief the TemporaryAllocator is used to alloc the temporary allocation
* which used by CUDA's async operation.
*
* The TemporaryAllocator contains a temp_allocation_queue which
* is used to store the temporary allocations. The allocation, which is
* allocated by TemporaryAllocator, is a unique_ptr, and when it is not held
* by any variable, it will be pushed into the temp_allocation_queue.
*
* There is one opportunity to free the allocations of temp_allocation_queue:
* - when the allocation size of opportunities exceeds a certain threshold
* (defined by FLAGS_limit_of_temporary_allocation).
*
* */
class TemporaryAllocator : public memory::allocation::Allocator {
public:
explicit TemporaryAllocator(platform::Place place);
......
......@@ -14,8 +14,7 @@
#include "paddle/fluid/platform/temporary_allocator.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/create_tensor_with_allocationptr.h"
#include "paddle/fluid/framework/tensor_util.h"
DECLARE_double(limit_of_temporary_allocation);
namespace paddle {
......@@ -47,6 +46,7 @@ TEST(temporary_allocator, temporary_allocator) {
TEST(temporary_allocator, add_callback) {
#ifdef PADDLE_WITH_CUDA
const double limit = FLAGS_limit_of_temporary_allocation;
FLAGS_limit_of_temporary_allocation = 10;
platform::CUDAPlace gpu_place(0);
TemporaryAllocator gpu_alloc(gpu_place);
......@@ -63,7 +63,7 @@ TEST(temporary_allocator, add_callback) {
});
{ gpu_alloc.Allocate(100); }
PADDLE_ENFORCE(deleted);
FLAGS_limit_of_temporary_allocation = -1;
FLAGS_limit_of_temporary_allocation = limit;
#endif
}
......@@ -75,8 +75,8 @@ TEST(temporary_allocator, create_tensor_with_allocationptr) {
auto allocation = cpu_alloc.Allocate(memory_size);
void* address = allocation->ptr();
int numel = memory_size / sizeof(float);
framework::Tensor tensor =
GetTensor<float>(std::move(allocation), framework::make_ddim({numel}));
framework::Tensor tensor = framework::GetTensor<float>(
std::move(allocation), framework::make_ddim({numel}));
PADDLE_ENFORCE_EQ(address, tensor.data<float>());
PADDLE_ENFORCE_EQ(tensor.numel(), numel);
}
......@@ -90,8 +90,8 @@ TEST(temporary_allocator, create_tensor_with_allocationptr) {
auto allocation = gpu_alloc.Allocate(memory_size);
void* address = allocation->ptr();
int numel = memory_size / sizeof(float);
framework::Tensor tensor =
GetTensor<float>(std::move(allocation), framework::make_ddim({numel}));
framework::Tensor tensor = framework::GetTensor<float>(
std::move(allocation), framework::make_ddim({numel}));
PADDLE_ENFORCE_EQ(address, tensor.data<float>());
PADDLE_ENFORCE_EQ(tensor.numel(), numel);
}
......@@ -116,7 +116,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr2) {
{
auto allocation = cpu_alloc.Allocate(memory_size);
address = allocation->ptr();
framework::Tensor tensor = GetTensor<float>(
framework::Tensor tensor = framework::GetTensor<float>(
std::move(allocation), framework::make_ddim({numel}));
PADDLE_ENFORCE_EQ(address, tensor.data<float>());
PADDLE_ENFORCE_EQ(tensor.numel(), numel);
......@@ -138,7 +138,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr2) {
{
auto allocation = gpu_alloc.Allocate(memory_size);
address = allocation->ptr();
framework::Tensor tensor = GetTensor<float>(
framework::Tensor tensor = framework::GetTensor<float>(
std::move(allocation), framework::make_ddim({numel}));
PADDLE_ENFORCE_EQ(address, tensor.data<float>());
PADDLE_ENFORCE_EQ(tensor.numel(), numel);
......
......@@ -49,6 +49,9 @@ void BindConstValue(pybind11::module* m) {
op_proto_and_checker_maker.def(
"kOpNameScopeAttrName",
framework::OpProtoAndCheckerMaker::OpNamescopeAttrName);
op_proto_and_checker_maker.def(
"kOpCreationCallstackAttrName",
framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName);
}
} // namespace pybind
......
......@@ -44,6 +44,8 @@ class DataToLoDTensorConverter(object):
self.dtype = 'int64'
elif dtype == core.VarDesc.VarType.FP64:
self.dtype = 'float64'
elif dtype == core.VarDesc.VarType.FP16:
self.dtype = 'float16'
elif dtype == core.VarDesc.VarType.INT32:
self.dtype = 'int32'
elif dtype == core.VarDesc.VarType.UINT8:
......
......@@ -20,6 +20,7 @@ import os
import re
import six
import sys
import traceback
import numpy as np
......@@ -604,6 +605,10 @@ class Operator(object):
if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0:
del op_attrs[role_var_name]
callstack_var_name = op_maker.kOpCreationCallstackAttrName()
op_attrs[callstack_var_name] = list(
reversed(traceback.format_stack()))[1:]
if len(self.desc.type()) != 0:
return
if type is None:
......
......@@ -18,6 +18,7 @@ from . import framework
import numpy as np
import contextlib
from .core import VarDesc
from . import unique_name
__all__ = [
'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear',
......@@ -207,16 +208,39 @@ class UniformInitializer(Initializer):
# Initialization Ops should be prepended and not appended
if self._seed == 0:
self._seed = block.program.random_seed
# to be compatible of fp16 initalizers
if var.dtype == VarDesc.VarType.FP16:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(['gaussian_random', 'tmp'])),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False)
else:
out_dtype = var.dtype
out_var = var
op = block._prepend_op(
type="uniform_random",
outputs={"Out": var},
outputs={"Out": out_var},
attrs={
"shape": var.shape,
"dtype": int(var.dtype),
"dtype": out_dtype,
"min": self._low,
"max": self._high,
"seed": self._seed
})
if var.dtype == VarDesc.VarType.FP16:
block.append_op(
type="cast",
inputs={"X": out_var},
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
var.op = op
return op
......@@ -261,17 +285,39 @@ class NormalInitializer(Initializer):
# Initialization Ops should be prepended and not appended
if self._seed == 0:
self._seed = block.program.random_seed
# to be compatible of fp16 initalizers
if var.dtype == VarDesc.VarType.FP16:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(['gaussian_random', 'tmp'])),
shape=var.shape,
dtype=out_dtype,
type=VarDesc.VarType.LOD_TENSOR,
persistable=False)
else:
out_dtype = var.dtype
out_var = var
op = block._prepend_op(
type="gaussian_random",
outputs={"Out": var},
outputs={"Out": out_var},
attrs={
"shape": var.shape,
"dtype": int(var.dtype),
"dtype": out_dtype,
"mean": self._mean,
"std": self._std_dev,
"seed": self._seed,
"use_mkldnn": False
})
if var.dtype == VarDesc.VarType.FP16:
block.append_op(
type="cast",
inputs={"X": out_var},
outputs={"Out": var},
attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype})
var.op = op
return op
......
......@@ -2802,6 +2802,10 @@ def batch_norm(input,
helper = LayerHelper('batch_norm', **locals())
dtype = helper.input_dtype()
# use fp32 for bn parameter
if dtype == core.VarDesc.VarType.FP16:
dtype = core.VarDesc.VarType.FP32
input_shape = input.shape
if data_layout == 'NCHW':
channel_num = input_shape[1]
......@@ -2836,7 +2840,7 @@ def batch_norm(input,
trainable=False,
do_model_average=do_model_average_for_mean_and_var),
shape=param_shape,
dtype=input.dtype)
dtype=dtype)
mean.stop_gradient = True
variance = helper.create_parameter(
......@@ -2846,7 +2850,7 @@ def batch_norm(input,
trainable=False,
do_model_average=do_model_average_for_mean_and_var),
shape=param_shape,
dtype=input.dtype)
dtype=dtype)
variance.stop_gradient = True
# create output
......@@ -7944,7 +7948,7 @@ def unstack(x, axis=0, num=None):
num = x.shape[axis]
outs = []
for _ in num:
for _ in range(num):
outs.append(helper.create_variable_for_type_inference(x.dtype))
helper.append_op(
......
......@@ -368,6 +368,8 @@ class OpTest(unittest.TestCase):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
return [place]
else:
return []
else:
return []
places = [fluid.CPUPlace()]
......
......@@ -22,8 +22,10 @@ from op_test import OpTest
class TestAccuracyOp(OpTest):
def setUp(self):
self.op_type = "accuracy"
self.dtype = np.float32
self.init_dtype()
n = 8192
infer = np.random.random((n, 1)).astype("float32")
infer = np.random.random((n, 1)).astype(self.dtype)
indices = np.random.randint(0, 2, (n, 1))
label = np.random.randint(0, 2, (n, 1))
self.inputs = {'Out': infer, 'Indices': indices, "Label": label}
......@@ -34,14 +36,25 @@ class TestAccuracyOp(OpTest):
num_correct += 1
break
self.outputs = {
'Accuracy': np.array([num_correct / float(n)]).astype("float32"),
'Accuracy': np.array([num_correct / float(n)]).astype(self.dtype),
'Correct': np.array([num_correct]).astype("int32"),
'Total': np.array([n]).astype("int32")
}
def init_dtype(self):
pass
def test_check_output(self):
self.check_output()
class TestAccuracyOpFp16(TestAccuracyOp):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output(atol=1e-3)
if __name__ == '__main__':
unittest.main()
......@@ -21,14 +21,16 @@ from op_test import OpTest
class ElementwiseDivOp(OpTest):
def setUp(self):
self.op_type = "elementwise_div"
self.dtype = np.float32
self.init_dtype()
""" Warning
CPU gradient check error!
'X': np.random.random((32,84)).astype("float32"),
'Y': np.random.random((32,84)).astype("float32")
"""
self.inputs = {
'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"),
'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32")
'X': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype),
'Y': np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
}
self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}
......@@ -46,6 +48,9 @@ class ElementwiseDivOp(OpTest):
self.check_grad(
['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y'))
def init_dtype(self):
pass
class TestElementwiseDivOp_scalar(ElementwiseDivOp):
def setUp(self):
......@@ -126,5 +131,21 @@ class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp):
}
class TestElementwiseDivOpFp16(ElementwiseDivOp):
def init_dtype(self):
self.dtype = np.float16
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=1)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'], 'Out', max_relative_error=1, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(
['X'], 'Out', max_relative_error=1, no_grad_set=set('Y'))
if __name__ == '__main__':
unittest.main()
......@@ -135,5 +135,10 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
}
class TestElementwiseMulOpFp16(ElementwiseMulOp):
def init_dtype(self):
self.dtype = np.float16
if __name__ == '__main__':
unittest.main()
......@@ -22,12 +22,22 @@ from op_test import OpTest
class TestFillZerosLikeOp(OpTest):
def setUp(self):
self.op_type = "fill_zeros_like"
self.inputs = {'X': np.random.random((219, 232)).astype("float32")}
self.dtype = np.float32
self.init_dtype()
self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
self.outputs = {'Out': np.zeros_like(self.inputs["X"])}
def init_dtype(self):
pass
def test_check_output(self):
self.check_output()
class TestFillZerosLikeOpFp16(TestFillZerosLikeOp):
def init_dtype(self):
self.dtype = np.float16
if __name__ == "__main__":
unittest.main()
......@@ -24,11 +24,13 @@ from op_test import OpTest
class TestMomentumOp1(OpTest):
def setUp(self):
self.op_type = "momentum"
self.dtype = np.float32
self.init_dtype()
param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
velocity = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([0.001]).astype("float32")
param = np.random.random((123, 321)).astype(self.dtype)
grad = np.random.random((123, 321)).astype(self.dtype)
velocity = np.zeros((123, 321)).astype(self.dtype)
learning_rate = np.array([0.001]).astype(self.dtype)
mu = 0.0001
use_nesterov = False
......@@ -50,10 +52,21 @@ class TestMomentumOp1(OpTest):
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
def init_dtype(self):
pass
def test_check_output(self):
self.check_output()
class TestMomentumOpFp16(TestMomentumOp1):
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output(atol=1e-3)
class TestMomentumOp2(OpTest):
'''Test Momentum with default values for attributes
'''
......
......@@ -69,7 +69,7 @@ class TestOperator(unittest.TestCase):
set(mul_op.attr_names),
set([
"x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var",
"op_namescope"
"op_namescope", "op_callstack"
]))
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
......
......@@ -23,8 +23,11 @@ class TestTopkOp(OpTest):
def setUp(self):
self.set_args()
self.op_type = "top_k"
self.dtype = np.float32
self.init_dtype()
k = self.top_k
input = np.random.random((self.row, k)).astype("float32")
input = np.random.random((self.row, k)).astype(self.dtype)
output = np.ndarray((self.row, k))
indices = np.ndarray((self.row, k)).astype("int64")
......@@ -38,6 +41,9 @@ class TestTopkOp(OpTest):
self.outputs = {'Out': output, 'Indices': indices}
def init_dtype(self):
pass
def set_args(self):
self.row = 32
self.top_k = 1
......@@ -46,6 +52,11 @@ class TestTopkOp(OpTest):
self.check_output()
class TestTopkOpFp16(TestTopkOp):
def init_dtype(self):
self.dtype = np.float16
class TestTopkOp3d(OpTest):
def setUp(self):
self.op_type = "top_k"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册