Commit 6f5e64af authored by Yang Yu

Merge branch 'feature/check_nan_executor' into feature/rnn_gradient_check

@@ -21,6 +21,8 @@ cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc DEPS glog)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto)
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
@@ -29,7 +31,8 @@ cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute)
-cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference)
+cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
+shape_inference data_transform)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
@@ -64,4 +67,4 @@ cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
cc_test(init_test SRCS init_test.cc DEPS init)
-cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context)
+cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/data_transform.h"
namespace paddle {
namespace framework {
DataTransformFnMap& DataTransformFnMap::Instance() {
static DataTransformFnMap data_transform_map;
return data_transform_map;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <utility>
#include <vector>
#include "paddle/framework/op_kernel_type.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/variable.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/macros.h"
namespace paddle {
namespace framework {
using DataTransformFN =
std::function<void(const std::vector<platform::DeviceContext*> ctx,
const Variable& in, Variable* out)>;
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
struct KernelTypePairHash {
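// boost::hash_combine-style mixing: 0x9e3779b9 is the 32-bit golden-ratio
// constant, and the shifts spread each OpKernelType hash across the seed.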
static void HashCombine(const OpKernelType& t, std::size_t* seed) {
OpKernelType::Hash kernel_type_hasher;
(*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}
size_t operator()(const KernelTypePair& kernel_pair) const {
std::size_t seed = 0;
HashCombine(kernel_pair.first, &seed);
HashCombine(kernel_pair.second, &seed);
return seed;
}
};
using DataTransformMap =
std::unordered_map<KernelTypePair, DataTransformFN, KernelTypePairHash>;
class DataTransformFnMap {
public:
static DataTransformFnMap& Instance();
bool Has(const KernelTypePair& key_pair) const {
return map_.find(key_pair) != map_.end();
}
void Insert(const OpKernelType& left, const OpKernelType& right,
const DataTransformFN& data_tranform_fn) {
Insert(std::make_pair(left, right), data_tranform_fn);
}
void Insert(const KernelTypePair& kernel_type_pair,
const DataTransformFN& data_tranform_fn) {
PADDLE_ENFORCE(!Has(kernel_type_pair),
"KernelTypePair %s has been registered", "");
map_.insert({kernel_type_pair, data_tranform_fn});
}
const DataTransformFN& Get(const KernelTypePair& key_pair) const {
auto data_transformer = GetNullable(key_pair);
PADDLE_ENFORCE_NOT_NULL(data_transformer,
"DataTransformFN should not be NULL");
return *data_transformer;
}
const DataTransformFN* GetNullable(const KernelTypePair& key_pair) const {
auto it = map_.find(key_pair);
if (it == map_.end()) {
return nullptr;
} else {
return &(it->second);
}
}
const DataTransformMap& Map() const { return map_; }
private:
DataTransformFnMap() = default;
DataTransformMap map_;
DISABLE_COPY_AND_ASSIGN(DataTransformFnMap);
};
// generate unique name with __LINE__
// refs https://stackoverflow.com/questions/1597007
#define TOKENPASTE(x, y) x##y
#define TOKENPASTE2(x, y) TOKENPASTE(x, y)
#define REGISTER_DATA_TRANSFORM_FN(from, to, fn) \
static int TOKENPASTE2(fn_, __LINE__)() { \
::paddle::framework::DataTransformFnMap::Instance().Insert(from, to, fn); \
return 0; \
} \
static int TOKENPASTE2(var_, __LINE__) __attribute__((unused)) = \
TOKENPASTE2(fn_, __LINE__)()
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/data_transform.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
using namespace platform;
int test_value = 0;
OpKernelType kernel_type_1(proto::DataType::FP32, CPUPlace(), DataLayout::kNCHW,
LibraryType::kCUDNN);
OpKernelType kernel_type_2(proto::DataType::FP32, CUDAPlace(0),
DataLayout::kNCHW, LibraryType::kCUDNN);
OpKernelType kernel_type_3(proto::DataType::FP16, CUDAPlace(0),
DataLayout::kNCHW, LibraryType::kCUDNN);
void type1_to_type2(std::vector<platform::DeviceContext*> ctx,
const Variable& in, Variable* out) {
test_value++;
}
void type2_to_type3(std::vector<platform::DeviceContext*> ctx,
const Variable& in, Variable* out) {
test_value--;
}
void type1_to_type3(std::vector<platform::DeviceContext*> ctx,
const Variable& in, Variable* out) {
test_value += 2;
}
} // namespace framework
} // namespace paddle
namespace frw = paddle::framework;
REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_2,
frw::type1_to_type2);
REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_2, frw::kernel_type_3,
frw::type2_to_type3);
REGISTER_DATA_TRANSFORM_FN(frw::kernel_type_1, frw::kernel_type_3,
frw::type1_to_type3);
TEST(DataTransform, Register) {
using namespace paddle::framework;
using namespace paddle::platform;
auto& instance = DataTransformFnMap::Instance();
ASSERT_EQ(instance.Map().size(), 3UL);
std::vector<DeviceContext*> ctx;
paddle::framework::Variable in;
paddle::framework::Variable out;
instance.Get(std::make_pair(frw::kernel_type_1, frw::kernel_type_2))(ctx, in,
&out);
ASSERT_EQ(test_value, 1);
instance.Get(std::make_pair(frw::kernel_type_2, frw::kernel_type_3))(ctx, in,
&out);
ASSERT_EQ(test_value, 0);
instance.Get(std::make_pair(frw::kernel_type_1, frw::kernel_type_3))(ctx, in,
&out);
ASSERT_EQ(test_value, 2);
}
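// A hedged sketch (not part of the original commit): when a kernel-type pair
// may be unregistered, GetNullable() lets the caller branch instead of hitting
// the PADDLE_ENFORCE_NOT_NULL inside Get(). The pair below was never
// registered above, so the lookup is expected to return nullptr.
TEST(DataTransform, GetNullable) {
  auto& instance = frw::DataTransformFnMap::Instance();
  auto* fn = instance.GetNullable(
      std::make_pair(frw::kernel_type_3, frw::kernel_type_1));
  ASSERT_TRUE(fn == nullptr);
}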
@@ -14,18 +14,17 @@ limitations under the License. */
#include "paddle/framework/executor.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h"
DEFINE_bool(check_nan_inf, false,
"Checking whether operator produce NAN/INF or not. It will be "
"extremely slow so please use this flag wisely.");
namespace paddle {
namespace framework {
@@ -58,6 +57,19 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
}
}
static void CheckTensorNANOrInf(const std::string& name,
const framework::Tensor& tensor) {
if (tensor.type().hash_code() != typeid(float).hash_code() &&
tensor.type().hash_code() != typeid(double).hash_code()) {
return;
}
if (tensor.memory_size() == 0) {
return;
}
PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name);
PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name);
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
// TODO(tonyyang-svail):
@@ -101,6 +113,15 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(3) << op->DebugString();
op->Run(*local_scope, place_);
if (FLAGS_check_nan_inf) {
for (auto& vname : op->OutputVars(true)) {
auto* var = local_scope->FindVar(vname);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
}
}
}
}
if (create_local_scope) {
scope->DeleteScope(local_scope);
...
@@ -71,7 +71,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
places.emplace_back(platform::CPUPlace());
LOG(WARNING) << "Not specified CPU device, create CPU by Default.";
}
-platform::DeviceContextPool::Create(places);
+platform::DeviceContextPool::Init(places);
return true;
}
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include <algorithm>
#include <atomic>
#include "paddle/framework/data_transform.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/operator.h"
@@ -387,8 +388,8 @@ void OperatorWithKernel::Run(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
-platform::DeviceContextPool& pool = platform::DeviceContextPool::Get();
+platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-auto dev_ctx = pool.Borrow(place);
+auto dev_ctx = pool.Get(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
@@ -411,7 +412,38 @@ void OperatorWithKernel::Run(const Scope& scope,
expected_kernel_key);
}
-kernel_iter->second->Compute(ctx);
+if (actual_kernel_key == expected_kernel_key) {
kernel_iter->second->Compute(ctx);
} else {
Scope& op_scope = scope.NewScope();
auto input_vars = this->InputVars();
for (auto var_name : input_vars) {
op_scope.Var(var_name);
}
// TODO(qijun) get appropriate DeviceContext from DeviceContext pool
platform::DeviceContext* trans_dev_ctx = nullptr;
std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
// TODO(qijun) get appropriate DataTransformFN from global map
framework::DataTransformFN trans_fun = nullptr;
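// NOTE: trans_dev_ctx and trans_fun are still placeholders (hence the TODOs
// above); presumably they will be looked up from the DeviceContext pool and
// from DataTransformFnMap::Instance() once the registry is wired in. As
// written, invoking trans_fun here would throw std::bad_function_call.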
// Wait for transform starting
dev_ctx->Wait();
for (auto var_name : input_vars) {
trans_fun(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
op_scope.FindVar(var_name));
}
// Wait for data transform finishing
for (auto ctx : trans_dev_ctx_vec) {
ctx->Wait();
}
// Create a new ExecutionContext
ExecutionContext op_ctx(*this, op_scope, *dev_ctx);
kernel_iter->second->Compute(op_ctx);
}
}
OpKernelType OperatorWithKernel::GetActualKernelType(
...
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/data_type.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
@@ -205,5 +208,100 @@ inline void CopyToVector(const Tensor& src, std::vector<T>* dst) {
src_ptr, size);
}
template <typename Predicate, typename DevCtx>
struct AnyDTypeVisitor {
Predicate predicate_;
const Tensor& tensor_;
const DevCtx& ctx_;
Tensor* out_;
AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx,
Tensor* out)
: predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
template <typename T>
void operator()() const {
auto t = EigenVector<T>::Flatten(tensor_);
auto o = EigenScalar<bool>::From(*out_);
o.device(*ctx_.eigen_device()) = predicate_(t).any();
}
};
template <typename Predicate, typename DevCtx>
inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
const DevCtx& ctx, framework::Tensor* out) {
VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>(
predicate, tensor, ctx, out));
}
template <typename Predicate>
struct AnyVisitor : public boost::static_visitor<bool> {
const framework::Tensor& tensor_;
Predicate predicate_;
AnyVisitor(const framework::Tensor& tensor, Predicate predicate)
: tensor_(tensor), predicate_(std::move(predicate)) {}
template <typename Place>
bool operator()(const Place& place) const {
framework::Tensor out;
out.Resize({1});
out.mutable_data<bool>(place);
auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
AnyImpl(predicate_, tensor_, *ctx, &out);
return this->GetResult(out, place);
}
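// GetResult is overloaded on the place: for CUDAPlace the one-element bool
// tensor lives on the GPU, so it is copied to a CPU tensor (with Wait() calls
// around the copy) before being read; for CPUPlace it is dereferenced directly.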
bool GetResult(const framework::Tensor& out,
const platform::CUDAPlace& gpu) const {
platform::CPUPlace cpu;
framework::Tensor tmp;
tmp.Resize({1});
tmp.mutable_data<bool>(cpu);
platform::DeviceContextPool::Instance().Get(gpu)->Wait();
CopyFrom(out, cpu, &tmp);
platform::DeviceContextPool::Instance().Get(gpu)->Wait();
return GetResult(tmp, cpu);
}
bool GetResult(const framework::Tensor& out,
const platform::CPUPlace& cpu) const {
return *out.data<bool>();
}
};
template <typename Predicate>
inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
AnyVisitor<Predicate> visitor(tensor, predicate);
auto place = tensor.place();
return platform::VisitPlace(place, visitor);
}
struct HasNANPredicate {
template <typename T>
auto operator()(const T& eigen_vec) const
-> decltype(std::declval<T>().isnan()) {
return eigen_vec.isnan();
}
};
inline bool HasNAN(const framework::Tensor& tensor) {
HasNANPredicate predicate;
return Any(tensor, predicate);
}
struct HasInfPredicate {
template <typename T>
auto operator()(const T& eigen_vec) const
-> decltype(std::declval<T>().isinf()) {
return eigen_vec.isinf();
}
};
inline bool HasInf(const framework::Tensor& tensor) {
HasInfPredicate predicate;
return Any(tensor, predicate);
}
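// Minimal usage sketch (hypothetical tensor `t`, mirroring the unit tests
// added later in this diff):
//   framework::Tensor t;
//   float* buf = t.mutable_data<float>({4}, platform::CPUPlace());
//   // ... fill buf ...
//   bool bad = framework::HasNAN(t) || framework::HasInf(t);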
} // namespace framework
} // namespace paddle
@@ -13,6 +13,7 @@
#include "paddle/framework/tensor_util.h"
#include <gtest/gtest.h>
#include <cmath>
#include <string>
namespace paddle {
@@ -230,5 +231,28 @@ TEST(CopyToVector, Tensor) {
#endif
}
TEST(IsNAN, CPU) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
float* buf = src.mutable_data<float>({3}, CPUPlace());
buf[0] = 0.0;
buf[1] = NAN;
buf[2] = 0.0;
ASSERT_TRUE(HasNAN(src));
}
TEST(IsInf, CPU) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
double* buf = src.mutable_data<double>({3}, CPUPlace());
buf[0] = 1.0;
buf[1] = INFINITY;
buf[2] = 0.0;
ASSERT_TRUE(HasInf(src));
}
} // namespace framework
} // namespace paddle
@@ -29,7 +29,7 @@ bool MKLDNNLRNLayer::init(const LayerMap& layerMap,
}
/* the size of inputs for norm-layer is 1 */
-CHECK_EQ(config_.inputs_size(), 1UL);
+CHECK_EQ(config_.inputs_size(), 1);
const NormConfig& conf = config_.inputs(0).norm_conf();
localSize_ = conf.size();
alpha_ = conf.scale();
...
@@ -22,8 +22,8 @@ class ActivationOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
-ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
+ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
-ctx->ShareLoD("X", /*->*/ "Y");
+ctx->ShareLoD("X", /*->*/ "Out");
}
};
@@ -32,7 +32,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
-ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
+ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
}
};
@@ -41,11 +41,11 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sigmoid operator");
-AddOutput("Y", "Output of Sigmoid operator");
+AddOutput("Out", "Output of Sigmoid operator");
AddComment(R"DOC(
Sigmoid Activation Operator
-$$y = \frac{1}{1 + e^{-x}}$$
+$$out = \frac{1}{1 + e^{-x}}$$
)DOC");
}
@@ -56,11 +56,11 @@ class LogSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
LogSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LogSigmoid operator");
-AddOutput("Y", "Output of LogSigmoid operator");
+AddOutput("Out", "Output of LogSigmoid operator");
AddComment(R"DOC(
Logsigmoid Activation Operator
-$$y = \log \frac{1}{1 + e^{-x}}$$
+$$out = \log \frac{1}{1 + e^{-x}}$$
)DOC");
}
@@ -71,11 +71,11 @@ class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
ExpOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Exp operator");
-AddOutput("Y", "Output of Exp operator");
+AddOutput("Out", "Output of Exp operator");
AddComment(R"DOC(
Exp Activation Operator.
-$y = e^x$
+$out = e^x$
)DOC");
}
@@ -86,11 +86,11 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
ReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu operator");
-AddOutput("Y", "Output of Relu operator");
+AddOutput("Out", "Output of Relu operator");
AddComment(R"DOC(
Relu Activation Operator.
-$y = \max(x, 0)$
+$out = \max(x, 0)$
)DOC");
}
@@ -101,12 +101,12 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LeakyRelu operator");
-AddOutput("Y", "Output of LeakyRelu operator");
+AddOutput("Out", "Output of LeakyRelu operator");
AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
AddComment(R"DOC(
LeakyRelu Activation Operator.
-$y = \max(x, \alpha * x)$
+$out = \max(x, \alpha * x)$
)DOC");
}
@@ -117,13 +117,13 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softshrink operator");
-AddOutput("Y", "Output of Softshrink operator");
+AddOutput("Out", "Output of Softshrink operator");
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
AddComment(R"DOC(
Softshrink Activation Operator.
$$
-y = \begin{cases}
+out = \begin{cases}
x - \lambda, \text{if } x > \lambda \\
x + \lambda, \text{if } x < -\lambda \\
0, \text{otherwise}
@@ -139,11 +139,11 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
TanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Tanh operator");
-AddOutput("Y", "Output of Tanh operator");
+AddOutput("Out", "Output of Tanh operator");
AddComment(R"DOC(
Tanh Activation Operator.
-$$y = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
+$$out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
)DOC");
}
@@ -154,11 +154,11 @@ class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
TanhShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of TanhShrink operator");
-AddOutput("Y", "Output of TanhShrink operator");
+AddOutput("Out", "Output of TanhShrink operator");
AddComment(R"DOC(
TanhShrink Activation Operator.
-$$y = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
+$$out = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
)DOC");
}
@@ -169,14 +169,14 @@ class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardShrink operator");
-AddOutput("Y", "Output of HardShrink operator");
+AddOutput("Out", "Output of HardShrink operator");
AddAttr<float>("threshold", "The value of threshold for HardShrink")
.SetDefault(0.5f);
AddComment(R"DOC(
HardShrink Activation Operator.
$$
-y = \begin{cases}
+out = \begin{cases}
x, \text{if } x > \lambda \\
x, \text{if } x < -\lambda \\
0, \text{otherwise}
@@ -192,11 +192,11 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
SqrtOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sqrt operator");
-AddOutput("Y", "Output of Sqrt operator");
+AddOutput("Out", "Output of Sqrt operator");
AddComment(R"DOC(
Sqrt Activation Operator.
-$y = \sqrt{x}$
+$out = \sqrt{x}$
)DOC");
}
@@ -207,11 +207,11 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
AbsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Abs operator");
-AddOutput("Y", "Output of Abs operator");
+AddOutput("Out", "Output of Abs operator");
AddComment(R"DOC(
Abs Activation Operator.
-$y = |x|$
+$out = |x|$
)DOC");
}
@@ -222,11 +222,11 @@ class CeilOpMaker : public framework::OpProtoAndCheckerMaker {
CeilOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Ceil operator");
-AddOutput("Y", "Output of Ceil operator");
+AddOutput("Out", "Output of Ceil operator");
AddComment(R"DOC(
Ceil Activation Operator.
-$y = ceil(x)$
+$out = ceil(x)$
)DOC");
}
@@ -237,11 +237,11 @@ class FloorOpMaker : public framework::OpProtoAndCheckerMaker {
FloorOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Floor operator");
-AddOutput("Y", "Output of Floor operator");
+AddOutput("Out", "Output of Floor operator");
AddComment(R"DOC(
Floor Activation Operator.
-$y = floor(x)$
+$out = floor(x)$
)DOC");
}
@@ -252,11 +252,11 @@ class RoundOpMaker : public framework::OpProtoAndCheckerMaker {
RoundOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Round operator");
-AddOutput("Y", "Output of Round operator");
+AddOutput("Out", "Output of Round operator");
AddComment(R"DOC(
Round Activation Operator.
-$y = [x]$
+$out = [x]$
)DOC");
}
@@ -267,11 +267,11 @@ class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker {
ReciprocalOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Reciprocal operator");
-AddOutput("Y", "Output of Reciprocal operator");
+AddOutput("Out", "Output of Reciprocal operator");
AddComment(R"DOC(
Reciprocal Activation Operator.
-$$y = \frac{1}{x}$$
+$$out = \frac{1}{x}$$
)DOC");
}
@@ -282,11 +282,11 @@ class LogOpMaker : public framework::OpProtoAndCheckerMaker {
LogOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Log operator");
-AddOutput("Y", "Output of Log operator");
+AddOutput("Out", "Output of Log operator");
AddComment(R"DOC(
Log Activation Operator.
-$y = \ln(x)$
+$out = \ln(x)$
Natural logarithm of x.
@@ -299,11 +299,11 @@ class SquareOpMaker : public framework::OpProtoAndCheckerMaker {
SquareOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Square operator");
-AddOutput("Y", "Output of Square operator");
+AddOutput("Out", "Output of Square operator");
AddComment(R"DOC(
Square Activation Operator.
-$y = x^2$
+$out = x^2$
)DOC");
}
@@ -314,11 +314,11 @@ class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
SoftplusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softplus operator");
-AddOutput("Y", "Output of Softplus operator");
+AddOutput("Out", "Output of Softplus operator");
AddComment(R"DOC(
Softplus Activation Operator.
-$y = \ln(1 + e^{x})$
+$out = \ln(1 + e^{x})$
)DOC");
}
@@ -329,11 +329,11 @@ class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker {
SoftsignOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softsign operator");
-AddOutput("Y", "Output of Softsign operator");
+AddOutput("Out", "Output of Softsign operator");
AddComment(R"DOC(
Softsign Activation Operator.
-$$y = \frac{x}{1 + |x|}$$
+$$out = \frac{x}{1 + |x|}$$
)DOC");
}
@@ -344,7 +344,7 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of BRelu operator");
-AddOutput("Y", "Output of BRelu operator");
+AddOutput("Out", "Output of BRelu operator");
AddAttr<float>("t_min", "The min marginal value of BRelu")
.SetDefault(static_cast<float>(0));
AddAttr<float>("t_max", "The max marginal value of BRelu")
@@ -352,7 +352,7 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
BRelu Activation Operator.
-$y = \max(\min(x, t_{min}), t_{max})$
+$out = \max(\min(x, t_{min}), t_{max})$
)DOC");
}
@@ -363,13 +363,13 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of SoftRelu operator");
-AddOutput("Y", "Output of SoftRelu operator");
+AddOutput("Out", "Output of SoftRelu operator");
AddAttr<float>("threshold", "The threshold value of SoftRelu")
.SetDefault(40.0f);
AddComment(R"DOC(
SoftRelu Activation Operator.
-$y = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
+$out = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
)DOC");
}
@@ -380,7 +380,7 @@ class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ELU operator");
-AddOutput("Y", "Output of ELU operator");
+AddOutput("Out", "Output of ELU operator");
AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
AddComment(R"DOC(
ELU Activation Operator.
@@ -388,7 +388,7 @@ ELU Activation Operator.
Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.
-$y = \max(0, x) + \min(0, \alpha * (e^x - 1))$
+$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$
)DOC");
}
@@ -399,13 +399,13 @@ class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu6 operator");
-AddOutput("Y", "Output of Relu6 operator");
+AddOutput("Out", "Output of Relu6 operator");
AddAttr<float>("threshold", "The threshold value of Relu6")
.SetDefault(6.0f);
AddComment(R"DOC(
Relu6 Activation Operator.
-$y = \min(\max(0, x), 6)$
+$out = \min(\max(0, x), 6)$
)DOC");
}
@@ -416,12 +416,12 @@ class PowOpMaker : public framework::OpProtoAndCheckerMaker {
PowOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Pow operator");
-AddOutput("Y", "Output of Pow operator");
+AddOutput("Out", "Output of Pow operator");
AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
AddComment(R"DOC(
Pow Activation Operator.
-$y = x^{factor}$
+$out = x^{factor}$
)DOC");
}
@@ -432,7 +432,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of STanh operator");
-AddOutput("Y", "Output of STanh operator");
+AddOutput("Out", "Output of STanh operator");
AddAttr<float>("scale_a", "The scale parameter of a for the input")
.SetDefault(2.0f / 3.0f);
AddAttr<float>("scale_b", "The scale parameter of b for the input")
@@ -440,7 +440,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
STanh Activation Operator.
-$$y = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
+$$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
)DOC");
}
@@ -451,14 +451,14 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ThresholdedRelu operator");
-AddOutput("Y", "Output of ThresholdedRelu operator");
+AddOutput("Out", "Output of ThresholdedRelu operator");
AddAttr<float>("threshold", "The threshold location of activation")
.SetDefault(1.0f);
AddComment(R"DOC(
ThresholdedRelu Activation Operator.
$$
-y = \begin{cases}
+out = \begin{cases}
x, \text{if } x > threshold \\
0, \text{otherwise}
\end{cases}
@@ -473,7 +473,7 @@ class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardSigmoid operator");
-AddOutput("Y", "Output of HardSigmoid operator");
+AddOutput("Out", "Output of HardSigmoid operator");
AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
.SetDefault(0.2f);
AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
@@ -484,7 +484,7 @@ HardSigmoid Activation Operator.
Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.
-$y = \max(0, \min(1, slope * x + shift))$
+$out = \max(0, \min(1, slope * x + shift))$
The slope should be positive. The offset can be either positive or negative.
The default slope and shift are set according to the above reference.
@@ -499,12 +499,12 @@ class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Swish operator");
-AddOutput("Y", "Output of Swish operator");
+AddOutput("Out", "Output of Swish operator");
AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
AddComment(R"DOC(
Swish Activation Operator.
-$$y = \frac{x}{1 + e^{- \beta x}}$$
+$$out = \frac{x}{1 + e^{- \beta x}}$$
)DOC");
}
...
@@ -27,11 +27,11 @@ class ActivationKernel
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
-auto* Y = context.Output<framework::Tensor>("Y");
+auto* Out = context.Output<framework::Tensor>("Out");
-Y->mutable_data<T>(context.GetPlace());
+Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
-auto y = framework::EigenVector<T>::Flatten(*Y);
+auto out = framework::EigenVector<T>::Flatten(*Out);
auto* place =
context.template device_context<DeviceContext>().eigen_device();
Functor functor;
@@ -40,7 +40,7 @@ class ActivationKernel
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
-functor(*place, x, y);
+functor(*place, x, out);
}
};
@@ -51,14 +51,15 @@ class ActivationGradKernel
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
-auto* Y = context.Input<framework::Tensor>("Y");
+auto* Out = context.Input<framework::Tensor>("Out");
-auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
+auto* dOut =
+context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace());
-auto dy = framework::EigenVector<T>::Flatten(*dY);
+auto dout = framework::EigenVector<T>::Flatten(*dOut);
auto x = framework::EigenVector<T>::Flatten(*X);
-auto y = framework::EigenVector<T>::Flatten(*Y);
+auto out = framework::EigenVector<T>::Flatten(*Out);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto* place =
context.template device_context<DeviceContext>().eigen_device();
@@ -67,7 +68,7 @@ class ActivationGradKernel
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
-functor(*place, x, y, dy, dx);
+functor(*place, x, out, dout, dx);
}
};
@@ -83,17 +84,18 @@ struct BaseActivationFunctor {
// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor : public BaseActivationFunctor<T> {
-template <typename Device, typename X, typename Y>
+template <typename Device, typename X, typename Out>
-void operator()(Device d, X x, Y y) const {
+void operator()(Device d, X x, Out out) const {
-y.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
+out.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
}
};
template <typename T>
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
-template <typename Device, typename X, typename Y, typename dY, typename dX>
-void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-dx.device(d) = dy * y * (static_cast<T>(1) - y);
+template <typename Device, typename X, typename Out, typename dOut,
+typename dX>
+void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+dx.device(d) = dout * out * (static_cast<T>(1) - out);
}
};
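// Note: sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)), so SigmoidGradFunctor is
// written entirely in terms of the forward output `out`; the input `x` is
// never read.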
@@ -101,7 +103,7 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// We can rewrite the above equation as:
-// y = -log( exp(0) + exp(-x)) [since exp(0) = 1]
+// out = -log( exp(0) + exp(-x)) [since exp(0) = 1]
// = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
// = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) - exp(max(-x, 0)) * exp(-x -
// max(-x, 0)))
@@ -112,10 +114,10 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
// + exp(-x - max(-x, 0))))
template <typename T>
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
-template <typename Device, typename X, typename Y>
+template <typename Device, typename X, typename Out>
-void operator()(Device d, X x, Y y) const {
+void operator()(Device d, X x, Out out) const {
auto temp = (-x).cwiseMax(static_cast<T>(0)); // temp = max(-x, 0)
-y.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
+out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
}
};
@@ -124,62 +126,66 @@ struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
// exp(-x - max(-x, 0)))
template <typename T>
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
-template <typename Device, typename X, typename Y, typename dY, typename dX>
-void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+template <typename Device, typename X, typename Out, typename dOut,
+typename dX>
+void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp = (-x).cwiseMax(static_cast<T>(0)); // temp = max(-x, 0)
dx.device(d) =
-dy * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
+dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
}
};
// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
-template <typename Device, typename X, typename Y>
+template <typename Device, typename X, typename Out>
-void operator()(Device d, X x, Y y) const {
+void operator()(Device d, X x, Out out) const {
-y.device(d) = x.exp();
+out.device(d) = x.exp();
}
};
template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
-template <typename Device, typename X, typename Y, typename dY, typename dX>
-void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-dx.device(d) = dy * y;
+template <typename Device, typename X, typename Out, typename dOut,
+typename dX>
+void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+dx.device(d) = dout * out;
}
};
// relu(x) = max(x, 0)
template <typename T>
struct ReluFunctor : public BaseActivationFunctor<T> {
-template <typename Device, typename X, typename Y>
+template <typename Device, typename X, typename Out>
-void operator()(Device d, X x, Y y) const {
+void operator()(Device d, X x, Out out) const {
-y.device(d) = x.cwiseMax(static_cast<T>(0));
+out.device(d) = x.cwiseMax(static_cast<T>(0));
}
};
template <typename T>
struct ReluGradFunctor : public BaseActivationFunctor<T> {
-template <typename Device, typename X, typename Y, typename dY, typename dX>
-void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
+template <typename Device, typename X, typename Out, typename dOut,
+typename dX>
+void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>();
}
};
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
-template <typename Device, typename X, typename Y>
+template <typename Device, typename X, typename Out>
-void operator()(Device d, X x, Y y) const {
+void operator()(Device d, X x, Out out) const {
-y.device(d) = x.tanh();
+out.device(d) = x.tanh();
}
};
template <typename T>
struct TanhGradFunctor : public BaseActivationFunctor<T> {
-template <typename Device, typename X, typename Y, typename dY, typename dX>
-void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-dx.device(d) = dy * (static_cast<T>(1) - y * y);
+template <typename Device, typename X, typename Out, typename dOut,
+typename dX>
+void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+dx.device(d) = dout * (static_cast<T>(1) - out * out);
}
};
@@ -187,17 +193,18 @@ struct TanhGradFunctor : public BaseActivationFunctor<T> {
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
-template <typename Device, typename X, typename Y>
+template <typename Device, typename X, typename Out>
-void operator()(Device d, X x, Y y) const {
+void operator()(Device d, X x, Out out) const {
-y.device(d) = x - x.tanh();
+out.device(d) = x - x.tanh();
}
};
template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
-template <typename Device, typename X, typename Y, typename dY, typename dX>
-void operator()(Device d, X x, Y y, dY dy, dX dx) const {
-dx.device(d) = dy * (x.tanh() * x.tanh());
+template <typename Device, typename X, typename Out, typename dOut,
+typename dX>
+void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+dx.device(d) = dout * (x.tanh() * x.tanh());
}
};
...@@ -210,11 +217,11 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> { ...@@ -210,11 +217,11 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}}; return {{"threshold", &threshold}};
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval(); auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval(); auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
y.device(d) = x * (temp1 + temp2); out.device(d) = x * (temp1 + temp2);
} }
}; };
...@@ -226,11 +233,12 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -226,11 +233,12 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}}; return {{"threshold", &threshold}};
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval(); auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval(); auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>(); dx.device(d) = dout * (temp1 + temp2).template cast<T>();
} }
}; };
...@@ -243,12 +251,12 @@ struct SoftShrinkFunctor : public BaseActivationFunctor<T> { ...@@ -243,12 +251,12 @@ struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
return {{"lambda", &lambda}}; return {{"lambda", &lambda}};
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
auto lambdaT = static_cast<T>(lambda); auto lambdaT = static_cast<T>(lambda);
auto temp1 = (x > lambdaT).template cast<T>().eval(); auto temp1 = (x > lambdaT).template cast<T>().eval();
auto temp2 = (x < -lambdaT).template cast<T>().eval(); auto temp2 = (x < -lambdaT).template cast<T>().eval();
y.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT); out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
} }
}; };
...@@ -258,46 +266,49 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -258,46 +266,49 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"lambda", &lambda}}; return {{"lambda", &lambda}};
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto lambdaT = static_cast<T>(lambda); auto lambdaT = static_cast<T>(lambda);
auto temp1 = (x > lambdaT).template cast<T>().eval(); auto temp1 = (x > lambdaT).template cast<T>().eval();
auto temp2 = (x < -lambdaT).template cast<T>().eval(); auto temp2 = (x < -lambdaT).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>(); dx.device(d) = dout * (temp1 + temp2).template cast<T>();
} }
}; };
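Both shrink functors mask out the band around the origin via temp1/temp2; soft-shrink additionally shifts the surviving values toward zero by lambda. A standalone scalar sketch of the two forward definitions (hypothetical helper names, plain doubles rather than Eigen expressions):

#include <cassert>
#include <cmath>

// hard_shrink(x) = x where |x| > threshold, 0 otherwise.
double hard_shrink(double x, double threshold) {
  return (x > threshold || x < -threshold) ? x : 0.0;
}
// soft_shrink(x) = x - lambda (x > lambda), x + lambda (x < -lambda), else 0.
double soft_shrink(double x, double lambda) {
  if (x > lambda) return x - lambda;
  if (x < -lambda) return x + lambda;
  return 0.0;
}

int main() {
  assert(hard_shrink(0.25, 0.5) == 0.0);
  assert(hard_shrink(-0.75, 0.5) == -0.75);
  assert(std::abs(soft_shrink(0.75, 0.5) - 0.25) < 1e-12);
  assert(soft_shrink(-0.25, 0.5) == 0.0);
  // The gradients above are just the indicator of the outer regions:
  // d/dx hard_shrink = d/dx soft_shrink = 1 for |x| > threshold, 0 otherwise.
  return 0;
}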
// sqrt(x) = x^(1/2) // sqrt(x) = x^(1/2)
template <typename T> template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> { struct SqrtFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = x.sqrt(); out.device(d) = x.sqrt();
} }
}; };
template <typename T> template <typename T>
struct SqrtGradFunctor : public BaseActivationFunctor<T> { struct SqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
const Y y_conj = Eigen::numext::conj(y); void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(0.5) * dy / y_conj; const Out out_conj = Eigen::numext::conj(out);
dx.device(d) = static_cast<T>(0.5) * dout / out_conj;
} }
}; };
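For real-valued T the Eigen::numext::conj call is the identity, so the sqrt backward reduces to dx = 0.5 * dout / out, i.e. d/dx sqrt(x) = 1 / (2 * sqrt(x)) expressed through the forward output. A minimal standalone check (plain doubles, illustrative names):

#include <cassert>
#include <cmath>

// d/dx sqrt(x) = 1 / (2 * sqrt(x)), written in terms of the forward output.
double sqrt_grad(double out, double dout) { return 0.5 * dout / out; }

int main() {
  const double x = 2.0, dout = 1.0, eps = 1e-6;
  const double analytic = sqrt_grad(std::sqrt(x), dout);
  const double numeric = (std::sqrt(x + eps) - std::sqrt(x - eps)) / (2 * eps);
  assert(std::abs(analytic - numeric) < 1e-6);
  return 0;
}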
// ceil(x) = ceiling(x) // ceil(x) = ceiling(x)
template <typename T> template <typename T>
struct CeilFunctor : public BaseActivationFunctor<T> { struct CeilFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = x.ceil(); out.device(d) = x.ceil();
} }
}; };
template <typename T> template <typename T>
struct ZeroGradFunctor : public BaseActivationFunctor<T> { struct ZeroGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(0) / x; dx.device(d) = static_cast<T>(0) / x;
} }
}; };
...@@ -305,86 +316,90 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> { ...@@ -305,86 +316,90 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
// floor(x) = flooring(x) // floor(x) = flooring(x)
template <typename T> template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> { struct FloorFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = x.floor(); out.device(d) = x.floor();
} }
}; };
// round(x) = [x] // round(x) = [x]
template <typename T> template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> { struct RoundFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = x.round(); out.device(d) = x.round();
} }
}; };
// abs(x) = |x| // abs(x) = |x|
template <typename T> template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> { struct AbsFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = x.abs(); out.device(d) = x.abs();
} }
}; };
template <typename T> template <typename T>
struct AbsGradFunctor : public BaseActivationFunctor<T> { struct AbsGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
dx.device(d) = dy * x.sign(); void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.sign();
} }
}; };
// reciprocal(x) = 1 / x // reciprocal(x) = 1 / x
template <typename T> template <typename T>
struct ReciprocalFunctor : public BaseActivationFunctor<T> { struct ReciprocalFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = static_cast<T>(1) / x; out.device(d) = static_cast<T>(1) / x;
} }
}; };
template <typename T> template <typename T>
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> { struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
dx.device(d) = dy * static_cast<T>(-1) * y * y; void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(-1) * out * out;
} }
}; };
// log(x) = natural logarithm of x // log(x) = natural logarithm of x
template <typename T> template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> { struct LogFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = x.log(); out.device(d) = x.log();
} }
}; };
template <typename T> template <typename T>
struct LogGradFunctor : public BaseActivationFunctor<T> { struct LogGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
dx.device(d) = dy * (static_cast<T>(1) / x); void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (static_cast<T>(1) / x);
} }
}; };
// square(x) = x^2 // square(x) = x^2
template <typename T> template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> { struct SquareFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = x.square(); out.device(d) = x.square();
} }
}; };
template <typename T> template <typename T>
struct SquareGradFunctor : public BaseActivationFunctor<T> { struct SquareGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
dx.device(d) = dy * static_cast<T>(2) * x; void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(2) * x;
} }
}; };
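Reciprocal is likewise written against the forward output (d/dx (1/x) = -1/x^2 = -out^2), while log and square differentiate through x directly. A standalone scalar sketch with finite-difference checks (hypothetical helper names):

#include <cassert>
#include <cmath>

double reciprocal_grad(double out, double dout) { return dout * -1.0 * out * out; }
double log_grad(double x, double dout) { return dout * (1.0 / x); }
double square_grad(double x, double dout) { return dout * 2.0 * x; }

// Central finite differences used to sanity-check each analytic gradient.
double fd_reciprocal(double x, double e) { return (1.0 / (x + e) - 1.0 / (x - e)) / (2 * e); }
double fd_log(double x, double e) { return (std::log(x + e) - std::log(x - e)) / (2 * e); }
double fd_square(double x, double e) { return ((x + e) * (x + e) - (x - e) * (x - e)) / (2 * e); }

int main() {
  const double x = 1.7, dout = 1.0, eps = 1e-6;
  assert(std::abs(reciprocal_grad(1.0 / x, dout) - fd_reciprocal(x, eps)) < 1e-5);
  assert(std::abs(log_grad(x, dout) - fd_log(x, eps)) < 1e-5);
  assert(std::abs(square_grad(x, dout) - fd_square(x, eps)) < 1e-5);
  return 0;
}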
...@@ -399,9 +414,9 @@ struct BReluFunctor : public BaseActivationFunctor<T> { ...@@ -399,9 +414,9 @@ struct BReluFunctor : public BaseActivationFunctor<T> {
return {{"t_min", &t_min}, {"t_max", &t_max}}; return {{"t_min", &t_min}, {"t_max", &t_max}};
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = out.device(d) =
x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max)); x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
} }
}; };
...@@ -413,9 +428,10 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -413,9 +428,10 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"t_min", &t_min}, {"t_max", &t_max}}; return {{"t_min", &t_min}, {"t_max", &t_max}};
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
dx.device(d) = dy * void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout *
((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max))) ((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
.template cast<T>(); .template cast<T>();
} }
...@@ -430,9 +446,9 @@ struct Relu6Functor : public BaseActivationFunctor<T> { ...@@ -430,9 +446,9 @@ struct Relu6Functor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}}; return {{"threshold", &threshold}};
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = out.device(d) =
x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold)); x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
} }
}; };
...@@ -443,9 +459,10 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> { ...@@ -443,9 +459,10 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}}; return {{"threshold", &threshold}};
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
dx.device(d) = dy * void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout *
((x > static_cast<T>(0)) * (x < static_cast<T>(threshold))) ((x > static_cast<T>(0)) * (x < static_cast<T>(threshold)))
.template cast<T>(); .template cast<T>();
} }
...@@ -458,10 +475,10 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> { ...@@ -458,10 +475,10 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0))) // Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
template <typename T> template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> { struct SoftplusFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Out out) {
auto temp = x.cwiseMax(static_cast<T>(0)); // temp = max(x, 0) auto temp = x.cwiseMax(static_cast<T>(0)); // temp = max(x, 0)
y.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log()); out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
} }
}; };
...@@ -471,19 +488,21 @@ struct SoftplusFunctor : public BaseActivationFunctor<T> { ...@@ -471,19 +488,21 @@ struct SoftplusFunctor : public BaseActivationFunctor<T> {
// exp(x - max(x, 0))) // exp(x - max(x, 0)))
template <typename T> template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> { struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) { typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) {
auto temp = x.cwiseMax(static_cast<T>(0)); // temp = max(x, 0) auto temp = x.cwiseMax(static_cast<T>(0)); // temp = max(x, 0)
dx.device(d) = dy * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp())); dx.device(d) =
dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
} }
}; };
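The max-shift in the comment is what keeps softplus finite for large positive inputs; a naive log(1 + exp(x)) overflows once exp(x) does. A standalone sketch contrasting the two forms (illustrative names, plain doubles):

#include <cassert>
#include <cmath>

// Naive form: exp(x) overflows to inf for large x.
double softplus_naive(double x) { return std::log(1.0 + std::exp(x)); }

// Stable form used above:
// softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0))).
double softplus_stable(double x) {
  const double t = x > 0.0 ? x : 0.0;
  return t + std::log(std::exp(-t) + std::exp(x - t));
}

int main() {
  // Both forms agree where the naive one is well behaved...
  assert(std::abs(softplus_naive(3.0) - softplus_stable(3.0)) < 1e-12);
  // ...but only the stable form survives large inputs (softplus(1000) ~ 1000).
  assert(std::isinf(softplus_naive(1000.0)));
  assert(std::abs(softplus_stable(1000.0) - 1000.0) < 1e-9);
  return 0;
}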
// softsign(x) = x / (1 + |x|) // softsign(x) = x / (1 + |x|)
template <typename T> template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> { struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Out out) {
y.device(d) = x / (static_cast<T>(1) + x.abs()); out.device(d) = x / (static_cast<T>(1) + x.abs());
} }
}; };
...@@ -491,10 +510,11 @@ struct SoftsignFunctor : public BaseActivationFunctor<T> { ...@@ -491,10 +510,11 @@ struct SoftsignFunctor : public BaseActivationFunctor<T> {
// Taken from https://en.wikipedia.org/wiki/Activation_function // Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T> template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> { struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) { typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) {
dx.device(d) = dx.device(d) =
dy * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square()); dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
} }
}; };
...@@ -505,11 +525,11 @@ struct SoftReluFunctor : public BaseActivationFunctor<T> { ...@@ -505,11 +525,11 @@ struct SoftReluFunctor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}}; return {{"threshold", &threshold}};
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
auto tmp = static_cast<T>(threshold); auto tmp = static_cast<T>(threshold);
auto temp = x.cwiseMax(-tmp).cwiseMin(tmp); auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
y.device(d) = (static_cast<T>(1) + temp.exp()).log(); out.device(d) = (static_cast<T>(1) + temp.exp()).log();
} }
}; };
...@@ -519,11 +539,12 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -519,11 +539,12 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}}; return {{"threshold", &threshold}};
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto tmp = static_cast<T>(threshold); auto tmp = static_cast<T>(threshold);
auto temp = ((x > -tmp) * (x < tmp)).template cast<T>().eval(); auto temp = ((x > -tmp) * (x < tmp)).template cast<T>().eval();
dx.device(d) = dy * (static_cast<T>(1) - (-y).exp()) * temp; dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
} }
}; };
...@@ -534,9 +555,9 @@ struct LeakyReluFunctor : public BaseActivationFunctor<T> { ...@@ -534,9 +555,9 @@ struct LeakyReluFunctor : public BaseActivationFunctor<T> {
return {{"alpha", &alpha}}; return {{"alpha", &alpha}};
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = x.cwiseMax(static_cast<T>(alpha) * x); out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
} }
}; };
...@@ -546,12 +567,13 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -546,12 +567,13 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}}; return {{"alpha", &alpha}};
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = static_cast<T>(alpha) * auto temp1 = static_cast<T>(alpha) *
(x < static_cast<T>(0)).template cast<T>().eval(); (x < static_cast<T>(0)).template cast<T>().eval();
auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval(); auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>(); dx.device(d) = dout * (temp1 + temp2).template cast<T>();
} }
}; };
...@@ -562,11 +584,11 @@ struct ELUFunctor : public BaseActivationFunctor<T> { ...@@ -562,11 +584,11 @@ struct ELUFunctor : public BaseActivationFunctor<T> {
return {{"alpha", &alpha}}; return {{"alpha", &alpha}};
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = x.cwiseMax(static_cast<T>(0)) + out.device(d) = x.cwiseMax(static_cast<T>(0)) +
(static_cast<T>(alpha) * (x.exp() - static_cast<T>(1))) (static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
.cwiseMin(static_cast<T>(0)); .cwiseMin(static_cast<T>(0));
} }
}; };
...@@ -576,10 +598,11 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> { ...@@ -576,10 +598,11 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}}; return {{"alpha", &alpha}};
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>() + void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dy * (y + static_cast<T>(alpha)) * dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
dout * (out + static_cast<T>(alpha)) *
(x < static_cast<T>(0)).template cast<T>(); (x < static_cast<T>(0)).template cast<T>();
} }
}; };
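For x < 0 the ELU forward output is alpha * (exp(x) - 1), so the out + alpha factor in the backward pass equals alpha * exp(x), which is exactly the derivative of the negative branch. A standalone scalar check of that factoring (hypothetical helper names):

#include <cassert>
#include <cmath>
#include <initializer_list>

double elu(double x, double alpha) {
  return x > 0.0 ? x : alpha * (std::exp(x) - 1.0);
}
// Same factoring as above: indicator of x > 0 plus (out + alpha) on x < 0.
double elu_grad(double x, double out, double dout, double alpha) {
  return dout * (x > 0.0 ? 1.0 : 0.0) +
         dout * (out + alpha) * (x < 0.0 ? 1.0 : 0.0);
}

int main() {
  const double alpha = 1.5, dout = 1.0, eps = 1e-6;
  for (double x : {-0.8, 0.7}) {
    const double analytic = elu_grad(x, elu(x, alpha), dout, alpha);
    const double numeric = (elu(x + eps, alpha) - elu(x - eps, alpha)) / (2 * eps);
    assert(std::abs(analytic - numeric) < 1e-5);
  }
  return 0;
}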
...@@ -591,9 +614,9 @@ struct PowFunctor : public BaseActivationFunctor<T> { ...@@ -591,9 +614,9 @@ struct PowFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}}; return {{"factor", &factor}};
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = x.pow(static_cast<T>(factor)); out.device(d) = x.pow(static_cast<T>(factor));
} }
}; };
...@@ -603,9 +626,10 @@ struct PowGradFunctor : public BaseActivationFunctor<T> { ...@@ -603,9 +626,10 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}}; return {{"factor", &factor}};
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
dx.device(d) = dy * static_cast<T>(factor) * void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(factor) *
x.pow(static_cast<T>(factor - static_cast<T>(1))); x.pow(static_cast<T>(factor - static_cast<T>(1)));
} }
}; };
...@@ -618,9 +642,9 @@ struct STanhFunctor : public BaseActivationFunctor<T> { ...@@ -618,9 +642,9 @@ struct STanhFunctor : public BaseActivationFunctor<T> {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = out.device(d) =
static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh(); static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
} }
}; };
...@@ -633,12 +657,13 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> { ...@@ -633,12 +657,13 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto a = static_cast<T>(scale_a); auto a = static_cast<T>(scale_a);
auto b = static_cast<T>(scale_b); auto b = static_cast<T>(scale_b);
auto temp = (a * x).tanh() * (a * x).tanh(); auto temp = (a * x).tanh() * (a * x).tanh();
dx.device(d) = dy * a * b * (static_cast<T>(1) - temp); dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
} }
}; };
...@@ -649,10 +674,10 @@ struct ThresholdedReluFunctor : public BaseActivationFunctor<T> { ...@@ -649,10 +674,10 @@ struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}}; return {{"threshold", &threshold}};
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
auto th = static_cast<T>(threshold); auto th = static_cast<T>(threshold);
y.device(d) = (x > th).template cast<T>() * x; out.device(d) = (x > th).template cast<T>() * x;
} }
}; };
...@@ -663,10 +688,11 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -663,10 +688,11 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}}; return {{"threshold", &threshold}};
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto th = static_cast<T>(threshold); auto th = static_cast<T>(threshold);
dx.device(d) = dy * (x > th).template cast<T>(); dx.device(d) = dout * (x > th).template cast<T>();
} }
}; };
...@@ -678,10 +704,11 @@ struct HardSigmoidFunctor : public BaseActivationFunctor<T> { ...@@ -678,10 +704,11 @@ struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
return {{"slope", &slope}, {"offset", &offset}}; return {{"slope", &slope}, {"offset", &offset}};
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
auto temp = x * static_cast<T>(slope) + static_cast<T>(offset); auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
y.device(d) = temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1)); out.device(d) =
temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
} }
}; };
...@@ -693,12 +720,13 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> { ...@@ -693,12 +720,13 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
return {{"slope", &slope}, {"offset", &offset}}; return {{"slope", &slope}, {"offset", &offset}};
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
dx.device(d) = void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dy * dx.device(d) = dout *
((y > static_cast<T>(0)) * (y < static_cast<T>(1))).template cast<T>() * ((out > static_cast<T>(0)) * (out < static_cast<T>(1)))
static_cast<T>(slope); .template cast<T>() *
static_cast<T>(slope);
} }
}; };
...@@ -709,9 +737,9 @@ struct SwishFunctor : public BaseActivationFunctor<T> { ...@@ -709,9 +737,9 @@ struct SwishFunctor : public BaseActivationFunctor<T> {
return {{"beta", &beta}}; return {{"beta", &beta}};
} }
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Y y) const { void operator()(Device d, X x, Out out) const {
y.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp()); out.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
} }
}; };
...@@ -722,12 +750,13 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> { ...@@ -722,12 +750,13 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
return {{"beta", &beta}}; return {{"beta", &beta}};
} }
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Out, typename dOut,
void operator()(Device d, X x, Y y, dY dy, dX dx) const { typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = static_cast<T>(1) / auto temp1 = static_cast<T>(1) /
(static_cast<T>(1) + (static_cast<T>(-beta) * x).exp()); (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
auto temp2 = temp1 * (static_cast<T>(1) - (beta * y)); auto temp2 = temp1 * (static_cast<T>(1) - (beta * out));
dx.device(d) = dy * ((beta * y) + temp2); dx.device(d) = dout * ((beta * out) + temp2);
} }
}; };
......
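The swish backward pass factors the derivative of x * sigmoid(beta * x) as beta * out + temp2, with temp2 = sigmoid(beta * x) * (1 - beta * out). A standalone scalar sketch verifying that factoring against finite differences (illustrative names, plain doubles):

#include <cassert>
#include <cmath>
#include <initializer_list>

double swish(double x, double beta) {
  return x / (1.0 + std::exp(-beta * x));
}
double swish_grad(double x, double out, double dout, double beta) {
  const double temp1 = 1.0 / (1.0 + std::exp(-beta * x));  // sigmoid(beta * x)
  const double temp2 = temp1 * (1.0 - beta * out);
  return dout * (beta * out + temp2);
}

int main() {
  const double beta = 1.3, dout = 1.0, eps = 1e-6;
  for (double x : {-1.1, 0.0, 0.9}) {
    const double analytic = swish_grad(x, swish(x, beta), dout, beta);
    const double numeric = (swish(x + eps, beta) - swish(x - eps, beta)) / (2 * eps);
    assert(std::abs(analytic - numeric) < 1e-5);
  }
  return 0;
}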
...@@ -35,8 +35,8 @@ class ArrayOp : public framework::OperatorBase { ...@@ -35,8 +35,8 @@ class ArrayOp : public framework::OperatorBase {
PADDLE_ENFORCE_EQ(i_tensor.numel(), 1); PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
size_t offset; size_t offset;
if (platform::is_gpu_place(i_tensor.place())) { if (platform::is_gpu_place(i_tensor.place())) {
......
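The operator changes below all follow the same call-site pattern: the pool singleton is now obtained with Instance() instead of the static Get(), and a per-place context with Get(place) instead of Borrow(place). A stripped-down standalone mock of that API shape (MockDeviceContextPool and the string-keyed map are inventions of this sketch, not the real platform classes):

#include <cassert>
#include <map>
#include <string>
#include <vector>

class MockDeviceContextPool {
 public:
  static MockDeviceContextPool& Instance() {  // was: static Get()
    assert(pool != nullptr && "call Init() first");
    return *pool;
  }
  static MockDeviceContextPool& Init(const std::vector<std::string>& places) {  // was: Create()
    if (pool == nullptr) pool = new MockDeviceContextPool(places);
    return *pool;
  }
  const int* Get(const std::string& place) const {  // was: Borrow(place)
    return &contexts_.at(place);
  }

 private:
  explicit MockDeviceContextPool(const std::vector<std::string>& places) {
    for (size_t i = 0; i < places.size(); ++i) contexts_[places[i]] = static_cast<int>(i);
  }
  static MockDeviceContextPool* pool;
  std::map<std::string, int> contexts_;  // stand-in for Place -> DeviceContext*
};

MockDeviceContextPool* MockDeviceContextPool::pool = nullptr;

int main() {
  MockDeviceContextPool::Init({"cpu", "gpu:0"});
  auto& pool = MockDeviceContextPool::Instance();
  assert(pool.Get("cpu") == pool.Get("cpu"));  // one context object per place
  return 0;
}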
...@@ -106,8 +106,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase { ...@@ -106,8 +106,9 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
} }
auto slice = out->Slice(out_offset, out_offset + len); auto slice = out->Slice(out_offset, out_offset + len);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool =
auto &dev_ctx = *pool.Borrow(place); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(x[x_idx].Slice(start_offset, end_offset), place, framework::CopyFrom(x[x_idx].Slice(start_offset, end_offset), place,
dev_ctx, &slice); dev_ctx, &slice);
......
...@@ -82,8 +82,8 @@ class AssignOp : public framework::OperatorBase { ...@@ -82,8 +82,8 @@ class AssignOp : public framework::OperatorBase {
out != nullptr, out != nullptr,
"The Output(Out) should not be null if the Input(X) is set."); "The Output(Out) should not be null if the Input(X) is set.");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
framework::VisitVarType(*x, AssignFunctor(out, dev_ctx)); framework::VisitVarType(*x, AssignFunctor(out, dev_ctx));
} }
......
...@@ -57,8 +57,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase { ...@@ -57,8 +57,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope, void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Borrow(dev_place); auto& dev_ctx = *pool.Get(dev_place);
framework::ExecutionContext ctx(*this, scope, dev_ctx); framework::ExecutionContext ctx(*this, scope, dev_ctx);
......
...@@ -195,8 +195,8 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope, ...@@ -195,8 +195,8 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
void CondOp::Run(const Scope& scope, const platform::Place& place) const { void CondOp::Run(const Scope& scope, const platform::Place& place) const {
// get device context from pool // get device context from pool
platform::DeviceContextPool& pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Borrow(place); auto& dev_ctx = *pool.Get(place);
PrepareDataForSubnet(scope, dev_ctx); PrepareDataForSubnet(scope, dev_ctx);
std::vector<framework::Scope*>& sub_scopes = GetSubScopes(scope); std::vector<framework::Scope*>& sub_scopes = GetSubScopes(scope);
......
...@@ -49,8 +49,8 @@ class FeedOp : public framework::OperatorBase { ...@@ -49,8 +49,8 @@ class FeedOp : public framework::OperatorBase {
auto *out_item = out_var->GetMutable<framework::FeedFetchType>(); auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(feed_item, place, dev_ctx, out_item); framework::CopyFrom(feed_item, place, dev_ctx, out_item);
out_item->set_lod(feed_item.lod()); out_item->set_lod(feed_item.lod());
......
...@@ -52,8 +52,8 @@ class FetchOp : public framework::OperatorBase { ...@@ -52,8 +52,8 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generates // FIXME(yuyang18): Should we assume the fetch operator always generates
// CPU outputs? // CPU outputs?
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item); CopyFrom(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
dev_ctx.Wait(); dev_ctx.Wait();
......
...@@ -49,8 +49,8 @@ class FillConstantOp : public framework::OperatorBase { ...@@ -49,8 +49,8 @@ class FillConstantOp : public framework::OperatorBase {
out.mutable_data(dev_place, framework::ToTypeIndex(data_type)); out.mutable_data(dev_place, framework::ToTypeIndex(data_type));
} }
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(dev_place); auto &dev_ctx = *pool.Get(dev_place);
math::set_constant(dev_ctx, &out, value); math::set_constant(dev_ctx, &out, value);
} }
}; };
......
...@@ -69,8 +69,9 @@ class FillOp : public framework::OperatorBase { ...@@ -69,8 +69,9 @@ class FillOp : public framework::OperatorBase {
if (!force_cpu && platform::is_gpu_place(place)) { if (!force_cpu && platform::is_gpu_place(place)) {
// Copy tensor to out // Copy tensor to out
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool =
auto &dev_ctx = *pool.Borrow(place); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(tensor, place, dev_ctx, &out); framework::CopyFrom(tensor, place, dev_ctx, &out);
} }
} }
......
...@@ -40,8 +40,8 @@ class LoadOp : public framework::OperatorBase { ...@@ -40,8 +40,8 @@ class LoadOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>(); auto *tensor = out_var->GetMutable<framework::LoDTensor>();
framework::DeserializeFromStream(fin, tensor); framework::DeserializeFromStream(fin, tensor);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
// copy CPU to GPU // copy CPU to GPU
......
...@@ -88,8 +88,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase { ...@@ -88,8 +88,9 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
auto slice = out[i].Slice(static_cast<int>(offset), auto slice = out[i].Slice(static_cast<int>(offset),
static_cast<int>(offset + len)); static_cast<int>(offset + len));
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool =
auto &dev_ctx = *pool.Borrow(place); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(x.Slice(static_cast<int>(each_range.begin), framework::CopyFrom(x.Slice(static_cast<int>(each_range.begin),
static_cast<int>(each_range.end)), static_cast<int>(each_range.end)),
......
...@@ -30,8 +30,8 @@ class MergeLoDTensorOp : public framework::OperatorBase { ...@@ -30,8 +30,8 @@ class MergeLoDTensorOp : public framework::OperatorBase {
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(dev_place); auto &dev_ctx = *pool.Get(dev_place);
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>(); auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>(); auto &mask = scope.FindVar(Input("Mask"))->Get<framework::LoDTensor>();
......
...@@ -305,7 +305,7 @@ int main(int argc, char **argv) { ...@@ -305,7 +305,7 @@ int main(int argc, char **argv) {
} }
VLOG(0) << " DeviceCount " << count; VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places); paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
......
...@@ -272,8 +272,9 @@ class RecurrentOp : public RecurrentBase { ...@@ -272,8 +272,9 @@ class RecurrentOp : public RecurrentBase {
false /*create_local_scope*/); false /*create_local_scope*/);
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool =
auto &dev_ctx = *pool.Borrow(place); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// Copy inside::output -> outside::output // Copy inside::output -> outside::output
// outside::output[seq_offset: seq_offset + 1] = inside::output // outside::output[seq_offset: seq_offset + 1] = inside::output
...@@ -326,8 +327,8 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -326,8 +327,8 @@ class RecurrentGradOp : public RecurrentBase {
auto *program = block->Program(); auto *program = block->Program();
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
for (size_t step_id = 0; step_id < seq_len; ++step_id) { for (size_t step_id = 0; step_id < seq_len; ++step_id) {
size_t seq_offset = reverse ? step_id : seq_len - step_id - 1; size_t seq_offset = reverse ? step_id : seq_len - step_id - 1;
......
...@@ -131,8 +131,8 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { ...@@ -131,8 +131,8 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
auto x_sliced = x.Slice(x_offset, x_offset + len); auto x_sliced = x.Slice(x_offset, x_offset + len);
auto out_sliced = out->Slice(out_offset, out_offset + len); auto out_sliced = out->Slice(out_offset, out_offset + len);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(x_sliced, out_sliced.place(), dev_ctx, &out_sliced); framework::CopyFrom(x_sliced, out_sliced.place(), dev_ctx, &out_sliced);
out_offset += len; out_offset += len;
return out_offset; return out_offset;
......
...@@ -91,8 +91,8 @@ class SaveOp : public framework::OperatorBase { ...@@ -91,8 +91,8 @@ class SaveOp : public framework::OperatorBase {
auto &tensor = var->Get<framework::LoDTensor>(); auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
framework::SerializeToStream(fout, tensor, dev_ctx); framework::SerializeToStream(fout, tensor, dev_ctx);
} }
......
...@@ -106,8 +106,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { ...@@ -106,8 +106,8 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
dx_tensor.mutable_data(x_tensor.place(), x_tensor.type()); dx_tensor.mutable_data(x_tensor.place(), x_tensor.type());
// get device context from pool // get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(place); auto &dev_ctx = *pool.Get(place);
if (dout_var == nullptr) { // dx_tensor fill zero if (dout_var == nullptr) { // dx_tensor fill zero
math::set_constant(dev_ctx, &dx_tensor, 0.0f); math::set_constant(dev_ctx, &dx_tensor, 0.0f);
......
...@@ -24,13 +24,13 @@ class SoftmaxOp : public framework::OperatorWithKernel { ...@@ -24,13 +24,13 @@ class SoftmaxOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SoftmaxOp should not be null."); "Input(X) of SoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Y) of SoftmaxOp should not be null."); "Output(Out) of SoftmaxOp should not be null.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(x_dims.size() == 2UL, PADDLE_ENFORCE(x_dims.size() == 2UL,
"The input of softmax op must be a matrix."); "The input of softmax op must be a matrix.");
ctx->SetOutputDim("Y", x_dims); ctx->SetOutputDim("Out", x_dims);
} }
}; };
...@@ -41,7 +41,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -41,7 +41,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", AddInput("X",
"The input tensor of softmax. " "The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions]."); "2-D with shape [batch_size, input_feature_dimensions].");
AddOutput("Y", "The normalized values with the same shape as X."); AddOutput("Out", "The normalized values with the same shape as X.");
AddComment(R"DOC( AddComment(R"DOC(
Softmax Operator. Softmax Operator.
...@@ -59,7 +59,7 @@ exponential values of all the other dimensions is the output of the softmax ...@@ -59,7 +59,7 @@ exponential values of all the other dimensions is the output of the softmax
operator. operator.
For each row $i$ and each column $j$ in Input(X), we have: For each row $i$ and each column $j$ in Input(X), we have:
$$Y[i, j] = \frac{\exp(X[i, j])}{\sum_j \exp(X[i, j])}$$ $$Out[i, j] = \frac{\exp(X[i, j])}{\sum_j \exp(X[i, j])}$$
)DOC"); )DOC");
} }
...@@ -70,12 +70,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { ...@@ -70,12 +70,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Y@GRAD) should be not null."); "Input(Out@GRAD) should be not null.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Y"), PADDLE_ENFORCE_EQ(ctx->GetInputDim("Out"),
ctx->GetInputDim(framework::GradVarName("Y")), ctx->GetInputDim(framework::GradVarName("Out")),
"Input(Y) and its gradients should have a same shape."); "Input(Out) and its gradients should have a same shape.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
......
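The renamed output keeps the documented formula Out[i, j] = exp(X[i, j]) / sum_j exp(X[i, j]). A small standalone numeric illustration for one row (illustrative helper name, with the usual max-shift for stability):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Softmax of a single row, shifted by the row max before exponentiating.
std::vector<double> softmax_row(std::vector<double> x) {
  const double mx = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (double& v : x) { v = std::exp(v - mx); sum += v; }
  for (double& v : x) v /= sum;
  return x;
}

int main() {
  const auto out = softmax_row({1.0, 2.0, 3.0});
  double total = 0.0;
  for (double v : out) total += v;
  assert(std::abs(total - 1.0) < 1e-12);       // probabilities sum to one
  assert(out[2] > out[1] && out[1] > out[0]);  // ordering of logits preserved
  return 0;
}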
...@@ -26,13 +26,13 @@ class SoftmaxKernel : public framework::OpKernel<T> { ...@@ -26,13 +26,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X"); auto* X = context.Input<Tensor>("X");
auto* Y = context.Output<Tensor>("Y"); auto* Out = context.Output<Tensor>("Out");
// allocate memory on device. // allocate memory on device.
Y->mutable_data<T>(context.GetPlace()); Out->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<DeviceContext, T>()( math::SoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), X, Y); context.template device_context<DeviceContext>(), X, Out);
} }
}; };
...@@ -40,15 +40,15 @@ template <typename DeviceContext, typename T> ...@@ -40,15 +40,15 @@ template <typename DeviceContext, typename T>
class SoftmaxGradKernel : public framework::OpKernel<T> { class SoftmaxGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* Y = context.Input<Tensor>("Y"); auto* Out = context.Input<Tensor>("Out");
auto* dY = context.Input<Tensor>(framework::GradVarName("Y")); auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X")); auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
// allocate memory on device. // allocate memory on device.
dX->mutable_data<T>(context.GetPlace()); dX->mutable_data<T>(context.GetPlace());
math::SoftmaxGradFunctor<DeviceContext, T>()( math::SoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), Y, dY, dX); context.template device_context<DeviceContext>(), Out, dOut, dX);
} }
}; };
......
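The gradient kernel only consumes Out and Out@GRAD because the softmax Jacobian can be expressed through the forward output: dX_j = Out_j * (dOut_j - sum_k dOut_k * Out_k). A standalone sketch verifying that identity against finite differences (plain vectors and illustrative names; the real kernel delegates to math::SoftmaxGradFunctor):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

std::vector<double> softmax(std::vector<double> x) {
  const double mx = *std::max_element(x.begin(), x.end());
  double sum = 0.0;
  for (double& v : x) { v = std::exp(v - mx); sum += v; }
  for (double& v : x) v /= sum;
  return x;
}

// dX_j = Out_j * (dOut_j - sum_k dOut_k * Out_k): only Out and dOut are needed.
std::vector<double> softmax_grad(const std::vector<double>& out,
                                 const std::vector<double>& dout) {
  double dot = 0.0;
  for (size_t k = 0; k < out.size(); ++k) dot += dout[k] * out[k];
  std::vector<double> dx(out.size());
  for (size_t j = 0; j < out.size(); ++j) dx[j] = out[j] * (dout[j] - dot);
  return dx;
}

int main() {
  const std::vector<double> x = {0.1, -1.2, 2.0};
  const std::vector<double> dout = {0.3, -0.7, 1.1};  // upstream gradient
  const auto dx = softmax_grad(softmax(x), dout);
  // Finite-difference check of L = sum_k dout_k * softmax(x)_k.
  const double eps = 1e-6;
  for (size_t i = 0; i < x.size(); ++i) {
    auto xp = x, xm = x;
    xp[i] += eps;
    xm[i] -= eps;
    const auto sp = softmax(xp), sm = softmax(xm);
    double lp = 0.0, lm = 0.0;
    for (size_t k = 0; k < x.size(); ++k) { lp += dout[k] * sp[k]; lm += dout[k] * sm[k]; }
    assert(std::abs(dx[i] - (lp - lm) / (2 * eps)) < 1e-5);
  }
  return 0;
}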
...@@ -45,8 +45,8 @@ class SplitLoDTensorOp : public framework::OperatorBase { ...@@ -45,8 +45,8 @@ class SplitLoDTensorOp : public framework::OperatorBase {
auto &x_lod = x.lod(); auto &x_lod = x.lod();
auto &mask_dim = mask.dims(); auto &mask_dim = mask.dims();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Borrow(dev_place); auto &dev_ctx = *pool.Get(dev_place);
std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()}; std::unique_ptr<framework::LoDTensor> cpu_mask{new framework::LoDTensor()};
if (platform::is_cpu_place(mask.place())) { if (platform::is_cpu_place(mask.place())) {
......
...@@ -40,8 +40,9 @@ class WriteToArrayOp : public ArrayOp { ...@@ -40,8 +40,9 @@ class WriteToArrayOp : public ArrayOp {
if (x_tensor.memory_size() > 0) { if (x_tensor.memory_size() > 0) {
auto *out_tensor = &out->at(offset); auto *out_tensor = &out->at(offset);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool =
auto &dev_ctx = *pool.Borrow(place); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
CopyFrom(x_tensor, place, dev_ctx, out_tensor); CopyFrom(x_tensor, place, dev_ctx, out_tensor);
out_tensor->set_lod(x_tensor.lod()); out_tensor->set_lod(x_tensor.lod());
...@@ -132,8 +133,9 @@ class ReadFromArrayOp : public ArrayOp { ...@@ -132,8 +133,9 @@ class ReadFromArrayOp : public ArrayOp {
auto *out_tensor = out->GetMutable<framework::LoDTensor>(); auto *out_tensor = out->GetMutable<framework::LoDTensor>();
size_t offset = GetOffset(scope, place); size_t offset = GetOffset(scope, place);
if (offset < x_array.size()) { if (offset < x_array.size()) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Get(); platform::DeviceContextPool &pool =
auto &dev_ctx = *pool.Borrow(place); platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::CopyFrom(x_array[offset], place, dev_ctx, out_tensor); framework::CopyFrom(x_array[offset], place, dev_ctx, out_tensor);
out_tensor->set_lod(x_array[offset].lod()); out_tensor->set_lod(x_array[offset].lod());
} else { } else {
......
...@@ -17,7 +17,7 @@ namespace platform { ...@@ -17,7 +17,7 @@ namespace platform {
DeviceContextPool* DeviceContextPool::pool = nullptr; DeviceContextPool* DeviceContextPool::pool = nullptr;
const platform::DeviceContext* DeviceContextPool::Borrow( const platform::DeviceContext* DeviceContextPool::Get(
const platform::Place& place) { const platform::Place& place) {
auto it = device_contexts_.find(place); auto it = device_contexts_.find(place);
if (it == device_contexts_.end()) { if (it == device_contexts_.end()) {
...@@ -28,24 +28,6 @@ const platform::DeviceContext* DeviceContextPool::Borrow( ...@@ -28,24 +28,6 @@ const platform::DeviceContext* DeviceContextPool::Borrow(
return it->second; return it->second;
} }
std::vector<const platform::DeviceContext*> DeviceContextPool::Borrow(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
PADDLE_ENFORCE_LE(places.size(), device_contexts_.size());
std::vector<const platform::DeviceContext*> borrowed_contexts;
for (auto& place : places) {
auto it = device_contexts_.find(place);
if (it != device_contexts_.end()) {
borrowed_contexts.emplace_back(it->second);
} else {
PADDLE_THROW(
"'Place' is not supported, Please re-compile with WITH_GPU "
"option");
}
}
return borrowed_contexts;
}
DeviceContextPool::DeviceContextPool( DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) { const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0); PADDLE_ENFORCE_GT(places.size(), 0);
......
...@@ -52,6 +52,14 @@ class CPUDeviceContext : public DeviceContext { ...@@ -52,6 +52,14 @@ class CPUDeviceContext : public DeviceContext {
std::unique_ptr<Eigen::DefaultDevice> eigen_device_; std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
}; };
template <typename Place>
struct DefaultDeviceContextType;
template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
using TYPE = CPUDeviceContext;
};
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice; class EigenCudaStreamDevice;
...@@ -90,6 +98,11 @@ class CUDADeviceContext : public DeviceContext { ...@@ -90,6 +98,11 @@ class CUDADeviceContext : public DeviceContext {
cublasHandle_t cublas_handle_; cublasHandle_t cublas_handle_;
}; };
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
using TYPE = CUDADeviceContext;
};
class CUDNNDeviceContext : public CUDADeviceContext { class CUDNNDeviceContext : public CUDADeviceContext {
public: public:
explicit CUDNNDeviceContext(CUDAPlace place); explicit CUDNNDeviceContext(CUDAPlace place);
...@@ -109,13 +122,13 @@ class DeviceContextPool { ...@@ -109,13 +122,13 @@ class DeviceContextPool {
public: public:
explicit DeviceContextPool(const std::vector<platform::Place>& places); explicit DeviceContextPool(const std::vector<platform::Place>& places);
static DeviceContextPool& Get() { static DeviceContextPool& Instance() {
PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!"); PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
return *pool; return *pool;
} }
/*! \brief Create should only be called by Init function */ /*! \brief Create should only be called by Init function */
static DeviceContextPool& Create(const std::vector<platform::Place>& places) { static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
if (pool == nullptr) { if (pool == nullptr) {
pool = new DeviceContextPool(places); pool = new DeviceContextPool(places);
} }
...@@ -123,13 +136,14 @@ class DeviceContextPool { ...@@ -123,13 +136,14 @@ class DeviceContextPool {
} }
/*! \brief Return handle of single device context. */ /*! \brief Return handle of single device context. */
const platform::DeviceContext* Borrow(const platform::Place& place); const platform::DeviceContext* Get(const platform::Place& place);
/*! \brief Return handle of multi-device context. */
std::vector<const platform::DeviceContext*> Borrow(
const std::vector<platform::Place>& places);
~DeviceContextPool() {} template <typename Place>
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
const Place& place) {
return reinterpret_cast<
const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
}
private: private:
static DeviceContextPool* pool; static DeviceContextPool* pool;
......
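DefaultDeviceContextType maps a place type to its concrete context type at compile time, which is what lets GetByPlace() return a typed pointer without a cast at the call site. A standalone sketch of the same trait pattern (CPUPlace, CPUDeviceContext and Get() here are mock stand-ins, not the real platform classes):

#include <cassert>
#include <string>

struct CPUPlace {};
struct CPUDeviceContext { std::string name = "cpu_ctx"; };

// Trait: which context type is the default for a given place type.
template <typename Place>
struct DefaultDeviceContextType;

template <>
struct DefaultDeviceContextType<CPUPlace> {
  using TYPE = CPUDeviceContext;
};

// Generic accessor returns an untyped pointer, like the pool's Get(place).
const void* Get(const CPUPlace&) {
  static CPUDeviceContext ctx;
  return &ctx;
}

// Typed accessor mirrors GetByPlace(): the trait picks the target type.
template <typename Place>
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(const Place& p) {
  return reinterpret_cast<const typename DefaultDeviceContextType<Place>::TYPE*>(Get(p));
}

int main() {
  const CPUDeviceContext* ctx = GetByPlace(CPUPlace{});  // no manual static_cast
  assert(ctx->name == "cpu_ctx");
  return 0;
}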
...@@ -71,35 +71,20 @@ TEST(Device, DeviceContextPool) { ...@@ -71,35 +71,20 @@ TEST(Device, DeviceContextPool) {
using paddle::platform::CPUPlace; using paddle::platform::CPUPlace;
using paddle::platform::CUDAPlace; using paddle::platform::CUDAPlace;
DeviceContextPool& pool = DeviceContextPool::Get(); DeviceContextPool& pool = DeviceContextPool::Instance();
auto cpu_dev_ctx1 = pool.Borrow(CPUPlace()); auto cpu_dev_ctx1 = pool.Get(CPUPlace());
auto cpu_dev_ctx2 = pool.Borrow(CPUPlace()); auto cpu_dev_ctx2 = pool.Get(CPUPlace());
EXPECT_TRUE(cpu_dev_ctx2 == cpu_dev_ctx1); ASSERT_EQ(cpu_dev_ctx2, cpu_dev_ctx1);
std::vector<Place> gpu_places; std::vector<Place> gpu_places;
int count = paddle::platform::GetCUDADeviceCount(); int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
gpu_places.emplace_back(CUDAPlace(i)); auto dev_ctx = pool.Get(CUDAPlace(i));
} ASSERT_NE(dev_ctx, nullptr);
auto dev_ctxs = pool.Borrow(gpu_places);
for (size_t i = 0; i < dev_ctxs.size(); ++i) {
auto* dev_ctx = static_cast<const CUDADeviceContext*>(dev_ctxs[i]);
// check same as CUDAPlace(i)
CUDAPlace place = boost::get<CUDAPlace>(dev_ctx->GetPlace());
EXPECT_EQ(place.GetDeviceId(), static_cast<int>(i));
} }
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
int dev_count = paddle::platform::GetCUDADeviceCount();
if (dev_count <= 1) {
LOG(WARNING) << "Cannot test multi-gpu DeviceContextPool, because the CUDA "
"device count is "
<< dev_count;
return 0;
}
std::vector<paddle::platform::Place> places; std::vector<paddle::platform::Place> places;
places.emplace_back(paddle::platform::CPUPlace()); places.emplace_back(paddle::platform::CPUPlace());
...@@ -109,7 +94,7 @@ int main(int argc, char** argv) { ...@@ -109,7 +94,7 @@ int main(int argc, char** argv) {
} }
VLOG(0) << " DeviceCount " << count; VLOG(0) << " DeviceCount " << count;
paddle::platform::DeviceContextPool::Create(places); paddle::platform::DeviceContextPool::Init(places);
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
......
@@ -144,7 +144,7 @@ int main(int argc, char** argv) {
   }
   VLOG(0) << " DeviceCount " << count;
-  paddle::platform::DeviceContextPool::Create(places);
+  paddle::platform::DeviceContextPool::Init(places);
   testing::InitGoogleTest(&argc, argv);
   return RUN_ALL_TESTS();
......
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 #include <iostream>
+#include "paddle/platform/enforce.h"
 #include "paddle/platform/variant.h"
 namespace paddle {
@@ -64,5 +64,31 @@ bool places_are_same_class(const Place &, const Place &);
 std::ostream &operator<<(std::ostream &, const Place &);
+template <typename Visitor>
+struct PlaceVisitorWrapper
+    : public boost::static_visitor<typename Visitor::result_type> {
+  const Visitor &visitor_;
+  explicit PlaceVisitorWrapper(const Visitor &visitor) : visitor_(visitor) {}
+  typename Visitor::result_type operator()(const CPUPlace &cpu) const {
+    return visitor_(cpu);
+  }
+  typename Visitor::result_type operator()(const CUDAPlace &cuda) const {
+#ifdef PADDLE_WITH_CUDA
+    return visitor_(cuda);
+#else
+    PADDLE_THROW("Paddle is not compiled with CUDA. Cannot visit cuda device");
+    return typename Visitor::result_type();
+#endif
+  }
+};
+template <typename Visitor>
+typename Visitor::result_type VisitPlace(const Place &place,
+                                         const Visitor &visitor) {
+  return boost::apply_visitor(PlaceVisitorWrapper<Visitor>(visitor), place);
+}
 }  // namespace platform
 }  // namespace paddle
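A hedged sketch of how the VisitPlace helper added above might be used: the visitor only needs a result_type typedef and an operator() per place type, which is what PlaceVisitorWrapper expects. The visitor below is illustrative, not from the commit; CUDAPlace::GetDeviceId() is the accessor already used elsewhere in this diff.

// Illustrative visitor for the new VisitPlace() helper; not part of the diff.
#include <string>
#include "paddle/platform/place.h"

struct PlaceNamePrinter {
  using result_type = std::string;  // required by PlaceVisitorWrapper

  std::string operator()(const paddle::platform::CPUPlace &) const {
    return "CPUPlace";
  }
  std::string operator()(const paddle::platform::CUDAPlace &cuda) const {
    return "CUDAPlace(" + std::to_string(cuda.GetDeviceId()) + ")";
  }
};

std::string PlaceName(const paddle::platform::Place &place) {
  // Dispatches on the concrete place type; in a CPU-only build visiting a
  // CUDAPlace throws, per the wrapper above.
  return paddle::platform::VisitPlace(place, PlaceNamePrinter());
}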
@@ -63,9 +63,10 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
       auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>(
           tensor.dims(), platform::CPUPlace()));
-      platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
+      platform::DeviceContextPool &pool =
+          platform::DeviceContextPool::Instance();
       auto dev_ctx = static_cast<const platform::CUDADeviceContext *>(
-          pool.Borrow(tensor.place()));
+          pool.Get(tensor.place()));
       paddle::platform::GpuMemcpyAsync(
           dst_ptr, src_ptr, sizeof(CUR_TYPE) * tensor.numel(),
@@ -137,9 +138,9 @@ void PyCUDATensorSetFromArray(
   self.Resize(framework::make_ddim(dims));
   auto *dst = self.mutable_data<T>(place);
-  platform::DeviceContextPool &pool = platform::DeviceContextPool::Get();
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   auto dev_ctx =
-      static_cast<const platform::CUDADeviceContext *>(pool.Borrow(place));
+      static_cast<const platform::CUDADeviceContext *>(pool.Get(place));
   paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(),
                                    cudaMemcpyHostToDevice, dev_ctx->stream());
 }
......
@@ -36,7 +36,7 @@ def __read_gflags_from_env__():
     """
     import sys
    import core
-    read_env_flags = ['use_pinned_memory']
+    read_env_flags = ['use_pinned_memory', 'check_nan_inf']
     if core.is_compile_gpu():
         read_env_flags.append('fraction_of_gpu_memory_to_use')
     core.init_gflags([sys.argv[0]] +
......
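Since read_env_flags feeds gflags, the new entry means the check can be switched on from the environment before fluid is imported. A rough sketch; the FLAGS_ prefix and the import path are assumptions, not something this diff shows.

# Rough sketch, not from the diff: enabling the new flag via the environment.
import os

# Assumption: gflags picks the value up as FLAGS_check_nan_inf; '1' presumably
# turns on the executor's NaN/Inf checking introduced on this branch.
os.environ['FLAGS_check_nan_inf'] = '1'

import paddle.v2.fluid as fluid  # __read_gflags_from_env__ runs at import time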
@@ -180,10 +180,22 @@ def save_inference_model(dirname,
     :return: None
     """
+    if isinstance(feeded_var_names, basestring):
+        feeded_var_names = [feeded_var_names]
+    else:
+        if not (bool(feeded_var_names) and all(
+                isinstance(name, basestring) for name in feeded_var_names)):
+            raise ValueError("'feed_var_names' should be a list of str.")
+    if isinstance(target_vars, Variable):
+        feeded_var_names = [feeded_var_names]
+    else:
+        if not (bool(target_vars) and all(
+                isinstance(var, Variable) for var in target_vars)):
+            raise ValueError("'target_vars' should be a list of Variable.")
     if main_program is None:
         main_program = default_main_program()
-    if not isinstance(target_vars, list):
-        target_vars = [target_vars]
     if not os.path.isdir(dirname):
         os.makedirs(dirname)
......
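The checks above replace the old blanket list-wrapping of target_vars with explicit validation. A standalone sketch of the rule for the feed names, using a hypothetical helper that mirrors the inserted code (Python 2, since it relies on basestring):

# Hypothetical helper mirroring the feeded_var_names check added above.
def _normalize_feed_names(feeded_var_names):
    if isinstance(feeded_var_names, basestring):  # a bare string is wrapped
        return [feeded_var_names]
    if not (bool(feeded_var_names) and all(
            isinstance(name, basestring) for name in feeded_var_names)):
        raise ValueError("'feed_var_names' should be a list of str.")
    return feeded_var_names


print(_normalize_feed_names("image"))             # -> ['image']
print(_normalize_feed_names(["image", "label"]))  # -> ['image', 'label']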
@@ -184,7 +184,7 @@ class LayerHelper(object):
         self.append_op(
             type=act_type,
             inputs={"X": [input_var]},
-            outputs={"Y": [tmp]},
+            outputs={"Out": [tmp]},
             attrs=act)
         return tmp
......
@@ -386,7 +386,8 @@ def square_error_cost(input, label, **kwargs):
     square_out = helper.create_tmp_variable(dtype=input.dtype)
     helper.append_op(
-        type='square', inputs={'X': [minus_out]}, outputs={'Y': [square_out]})
+        type='square', inputs={'X': [minus_out]},
+        outputs={'Out': [square_out]})
     return square_out
@@ -604,7 +605,7 @@ def sequence_pool(input, pool_type, **kwargs):
         sqrt : out.data = [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2),
                6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
         max : out.data = [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
     Args:
         input(variable): The input variable which is a LoDTensor.
         pool_type (string): The pooling type of sequence_pool.
@@ -616,7 +617,7 @@ def sequence_pool(input, pool_type, **kwargs):
     Examples:
         .. code-block:: python
             x = fluid.layers.data(name='x', shape=[7, 1],
                                   dtype='float32', lod_level=1)
             avg_x = fluid.layers.sequence_pool(input=x, pool_type='average')
@@ -654,7 +655,7 @@ def sequence_first_step(input, **kwargs):
              out.dim = [3, 1]
              with condition len(x.lod[-1]) - 1 == out.dims[0]
              out.data = [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)
     Args:
         input(variable): The input variable which is a LoDTensor.
@@ -664,7 +665,7 @@ def sequence_first_step(input, **kwargs):
     Examples:
         .. code-block:: python
             x = fluid.layers.data(name='x', shape=[7, 1],
                                   dtype='float32', lod_level=1)
             x_first_step = fluid.layers.sequence_first_step(input=x)
@@ -687,7 +688,7 @@ def sequence_last_step(input, **kwargs):
              out.dim = [3, 1]
              with condition len(x.lod[-1]) - 1 == out.dims[0]
              out.data = [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
     Args:
         input(variable): The input variable which is a LoDTensor.
@@ -697,7 +698,7 @@ def sequence_last_step(input, **kwargs):
     Examples:
         .. code-block:: python
             x = fluid.layers.data(name='x', shape=[7, 1],
                                   dtype='float32', lod_level=1)
             x_last_step = fluid.layers.sequence_last_step(input=x)
@@ -1132,7 +1133,7 @@ def reduce_sum(input, dim=None, keep_dim=False):
     Returns:
         Variable: The reduced Tensor variable.
     Examples:
         .. code-block:: python
@@ -1176,7 +1177,7 @@ def reduce_mean(input, dim=None, keep_dim=False):
     Returns:
         Variable: The reduced Tensor variable.
     Examples:
         .. code-block:: python
......
@@ -10,13 +10,13 @@ class TestExp(OpTest):
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
         }
-        self.outputs = {'Y': np.exp(self.inputs['X'])}
+        self.outputs = {'Out': np.exp(self.inputs['X'])}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestSigmoid(OpTest):
@@ -25,13 +25,13 @@ class TestSigmoid(OpTest):
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
         }
-        self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
+        self.outputs = {'Out': 1 / (1 + np.exp(-self.inputs['X']))}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.008)
+        self.check_grad(['X'], 'Out', max_relative_error=0.008)
 class TestLogSigmoid(OpTest):
@@ -40,13 +40,13 @@ class TestLogSigmoid(OpTest):
         self.inputs = {
             'X': np.random.uniform(-1, 1, [11, 17]).astype("float32")
         }
-        self.outputs = {'Y': np.log(1 / (1 + np.exp(-self.inputs['X'])))}
+        self.outputs = {'Out': np.log(1 / (1 + np.exp(-self.inputs['X'])))}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.008)
+        self.check_grad(['X'], 'Out', max_relative_error=0.008)
 class TestTanh(OpTest):
@@ -55,13 +55,13 @@ class TestTanh(OpTest):
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
         }
-        self.outputs = {'Y': np.tanh(self.inputs['X'])}
+        self.outputs = {'Out': np.tanh(self.inputs['X'])}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestTanhShrink(OpTest):
@@ -70,13 +70,13 @@ class TestTanhShrink(OpTest):
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [10, 17]).astype("float32")
         }
-        self.outputs = {'Y': self.inputs['X'] - np.tanh(self.inputs['X'])}
+        self.outputs = {'Out': self.inputs['X'] - np.tanh(self.inputs['X'])}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.008)
+        self.check_grad(['X'], 'Out', max_relative_error=0.008)
 class TestHardShrink(OpTest):
@@ -90,13 +90,13 @@ class TestHardShrink(OpTest):
         t = np.copy(x)
         t[(t >= -threshold) & (t <= threshold)] = 0
-        self.outputs = {'Y': t}
+        self.outputs = {'Out': t}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.005)
+        self.check_grad(['X'], 'Out', max_relative_error=0.005)
 class TestSoftShrink(OpTest):
@@ -110,13 +110,13 @@ class TestSoftShrink(OpTest):
         y = np.copy(self.inputs['X'])
         y = (y < -lambda_val) * (y + lambda_val) + (y > lambda_val) * (
             y - lambda_val)
-        self.outputs = {'Y': y}
+        self.outputs = {'Out': y}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestSqrt(OpTest):
@@ -125,13 +125,13 @@ class TestSqrt(OpTest):
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
         }
-        self.outputs = {'Y': np.sqrt(self.inputs['X'])}
+        self.outputs = {'Out': np.sqrt(self.inputs['X'])}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestAbs(OpTest):
@@ -144,13 +144,13 @@ class TestAbs(OpTest):
         # we should avoid this
         x[np.abs(x) < 0.005] = 0.02
         self.inputs = {'X': x}
-        self.outputs = {'Y': np.abs(self.inputs['X'])}
+        self.outputs = {'Out': np.abs(self.inputs['X'])}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestCeil(OpTest):
@@ -158,13 +158,13 @@ class TestCeil(OpTest):
         self.op_type = "ceil"
         x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
         self.inputs = {'X': x}
-        self.outputs = {'Y': np.ceil(self.inputs['X'])}
+        self.outputs = {'Out': np.ceil(self.inputs['X'])}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestFloor(OpTest):
@@ -173,13 +173,13 @@ class TestFloor(OpTest):
         x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
         self.inputs = {'X': x}
         # numpy floor need +1
-        self.outputs = {'Y': np.floor(self.inputs['X']) + 1.0}
+        self.outputs = {'Out': np.floor(self.inputs['X']) + 1.0}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestRound(OpTest):
@@ -187,13 +187,13 @@ class TestRound(OpTest):
         self.op_type = "round"
         x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
         self.inputs = {'X': x}
-        self.outputs = {'Y': np.round(self.inputs['X'])}
+        self.outputs = {'Out': np.round(self.inputs['X'])}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestRelu(OpTest):
@@ -203,13 +203,13 @@ class TestRelu(OpTest):
         # The same reason with TestAbs
         x[np.abs(x) < 0.005] = 0.02
         self.inputs = {'X': x}
-        self.outputs = {'Y': np.maximum(self.inputs['X'], 0)}
+        self.outputs = {'Out': np.maximum(self.inputs['X'], 0)}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestBRelu(OpTest):
@@ -227,13 +227,13 @@ class TestBRelu(OpTest):
         t = np.copy(x)
         t[t < t_min] = t_min
         t[t > t_max] = t_max
-        self.outputs = {'Y': t}
+        self.outputs = {'Out': t}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.02)
+        self.check_grad(['X'], 'Out', max_relative_error=0.02)
 class TestRelu6(OpTest):
@@ -248,14 +248,14 @@ class TestRelu6(OpTest):
         self.inputs = {'X': x}
         self.attrs = {'threshold': threshold}
         self.outputs = {
-            'Y': np.minimum(np.maximum(self.inputs['X'], 0), threshold)
+            'Out': np.minimum(np.maximum(self.inputs['X'], 0), threshold)
         }
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.02)
+        self.check_grad(['X'], 'Out', max_relative_error=0.02)
 class TestSoftRelu(OpTest):
@@ -271,13 +271,13 @@ class TestSoftRelu(OpTest):
         t = np.copy(x)
         t[t < -threshold] = -threshold
         t[t > threshold] = threshold
-        self.outputs = {'Y': np.log((np.exp(t) + 1))}
+        self.outputs = {'Out': np.log((np.exp(t) + 1))}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.02)
+        self.check_grad(['X'], 'Out', max_relative_error=0.02)
 class TestELU(OpTest):
@@ -290,27 +290,27 @@ class TestELU(OpTest):
         self.inputs = {'X': x}
         self.attrs = {'alpha': alpha}
         self.outputs = {
-            'Y': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
+            'Out': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
         }
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.02)
+        self.check_grad(['X'], 'Out', max_relative_error=0.02)
 class TestReciprocal(OpTest):
     def setUp(self):
         self.op_type = "reciprocal"
         self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")}
-        self.outputs = {'Y': np.reciprocal(self.inputs['X'])}
+        self.outputs = {'Out': np.reciprocal(self.inputs['X'])}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.01)
+        self.check_grad(['X'], 'Out', max_relative_error=0.01)
 class TestLog(OpTest):
@@ -319,13 +319,13 @@ class TestLog(OpTest):
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
         }
-        self.outputs = {'Y': np.log(self.inputs['X'])}
+        self.outputs = {'Out': np.log(self.inputs['X'])}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestSquare(OpTest):
@@ -334,13 +334,13 @@ class TestSquare(OpTest):
         self.inputs = {
             'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
         }
-        self.outputs = {'Y': np.square(self.inputs['X'])}
+        self.outputs = {'Out': np.square(self.inputs['X'])}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestPow(OpTest):
@@ -348,13 +348,13 @@ class TestPow(OpTest):
         self.op_type = "pow"
         self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")}
         self.attrs = {'factor': 3.0}
-        self.outputs = {'Y': np.power(self.inputs['X'], 3)}
+        self.outputs = {'Out': np.power(self.inputs['X'], 3)}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.02)
+        self.check_grad(['X'], 'Out', max_relative_error=0.02)
 class TestSTanh(OpTest):
@@ -366,13 +366,13 @@ class TestSTanh(OpTest):
         scale_a = 2.0 / 3.0
         scale_b = 1.7159
         self.attrs = {'scale_a': scale_a, 'scale_b': scale_b}
-        self.outputs = {'Y': scale_b * np.tanh(self.inputs['X'] * scale_a)}
+        self.outputs = {'Out': scale_b * np.tanh(self.inputs['X'] * scale_a)}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestSoftplus(OpTest):
@@ -381,13 +381,13 @@ class TestSoftplus(OpTest):
         self.inputs = {
             'X': np.random.uniform(-1, 1, [11, 17]).astype("float64")
         }
-        self.outputs = {'Y': np.log(1 + np.exp(self.inputs['X']))}
+        self.outputs = {'Out': np.log(1 + np.exp(self.inputs['X']))}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestSoftsign(OpTest):
@@ -397,14 +397,14 @@ class TestSoftsign(OpTest):
             'X': np.random.uniform(-1, 1, [11, 17]).astype("float32")
         }
         self.outputs = {
-            'Y': np.divide(self.inputs['X'], 1 + np.abs(self.inputs['X']))
+            'Out': np.divide(self.inputs['X'], 1 + np.abs(self.inputs['X']))
         }
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Out', max_relative_error=0.007)
 class TestThresholdedRelu(OpTest):
@@ -419,13 +419,13 @@ class TestThresholdedRelu(OpTest):
         self.inputs = {'X': X}
         self.attrs = {'threshold': threshold}
-        self.outputs = {'Y': (X > threshold) * X}
+        self.outputs = {'Out': (X > threshold) * X}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=self.relative_error)
+        self.check_grad(['X'], 'Out', max_relative_error=self.relative_error)
 class TestHardSigmoid(OpTest):
@@ -447,13 +447,13 @@ class TestHardSigmoid(OpTest):
             upper_threshold - 0.2
         temp = X * slope + offset
-        self.outputs = {'Y': np.maximum(0.0, np.minimum(1.0, temp))}
+        self.outputs = {'Out': np.maximum(0.0, np.minimum(1.0, temp))}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.002)
+        self.check_grad(['X'], 'Out', max_relative_error=0.002)
 class TestSwish(OpTest):
@@ -462,13 +462,13 @@ class TestSwish(OpTest):
         X = np.random.uniform(0.1, 1, [11, 17]).astype("float32")
         self.inputs = {'X': X}
         self.attrs = {'beta': 2.3}
-        self.outputs = {'Y': X * expit(self.attrs['beta'] * X)}
+        self.outputs = {'Out': X * expit(self.attrs['beta'] * X)}
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.008)
+        self.check_grad(['X'], 'Out', max_relative_error=0.008)
 if __name__ == "__main__":
......
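Consolidating the pattern of the renames above: activation OpTest cases now register the reference result and the gradient target under 'Out' instead of 'Y'. The class below is a sketch that mirrors TestExp and assumes the usual op_test.OpTest harness; it is not an addition from the commit.

# Sketch of the new naming convention; mirrors TestExp above, not part of the diff.
import numpy as np
from op_test import OpTest


class TestExpOutName(OpTest):
    def setUp(self):
        self.op_type = "exp"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
        }
        # The activation result is now registered as 'Out' (previously 'Y').
        self.outputs = {'Out': np.exp(self.inputs['X'])}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # The gradient check targets the renamed output as well.
        self.check_grad(['X'], 'Out', max_relative_error=0.007)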
@@ -7,7 +7,7 @@ def fc(X, W, Y):
     ret_v = core.Net.create()
     ret_v.append_op(Operator("mul", X="X", Y="W", Out="pre_activation"))
-    ret_v.append_op(Operator("sigmoid", X="pre_activation", Y=Y))
+    ret_v.append_op(Operator("sigmoid", X="pre_activation", Out=Y))
     ret_v.complete_add_op(True)
     return ret_v
@@ -30,7 +30,7 @@ Op(plain_net), inputs:{all[W, X, Y]}, outputs:{all[Out, fc.out, pre_activation]}
 Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}.
 Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}.
 Op(mul), inputs:{X[X], Y[W]}, outputs:{Out[pre_activation]}.
-Op(sigmoid), inputs:{X[pre_activation]}, outputs:{Y[fc.out]}.
+Op(sigmoid), inputs:{X[pre_activation]}, outputs:{Out[fc.out]}.
 '''
         self.assertEqual(expected, "\n" + str(net))
......
@@ -17,14 +17,14 @@ class TestSoftmaxOp(OpTest):
             'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32")
         }
         self.outputs = {
-            'Y': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
+            'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
         }
     def test_check_output(self):
         self.check_output()
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y')
+        self.check_grad(['X'], 'Out')
 if __name__ == "__main__":
......