提交 cd8f3e9e 编写于 作者: D dzhwinter

operator module is done

上级 2ec589a2
......@@ -279,10 +279,12 @@ op_library(array_to_lod_tensor_op DEPS lod_rank_table_op)
op_library(max_sequence_len_op DEPS lod_rank_table)
op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling)
if (NOT WIN32)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
op_library(lstmp_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute)
endif(NOT WIN32)
op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
op_library(cos_sim_op DEPS cos_sim_functor)
......
......@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <iterator>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -92,8 +94,11 @@ class RowwiseTransformIterator;
template <typename T, typename DeviceContext>
class MidWiseTransformIterator;
// NOTE(dzhwinter): ptrdiff_t in iterator is deprecated in C++17
template <typename T>
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
class RowwiseTransformIterator<T, platform::CPUDeviceContext>
: public std::iterator<std::random_access_iterator_tag, typename T,
std::ptrdiff_t, typename T*, typename T&> {
public:
RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
......@@ -124,7 +129,9 @@ class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
};
template <typename T>
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
class MidWiseTransformIterator<T, platform::CPUDeviceContext>
: public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
T*, T&> {
public:
MidWiseTransformIterator(const T* ptr, int n, int post)
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
......@@ -473,8 +480,13 @@ void ElemwiseGradComputeNoBroadcast(
const framework::Tensor& dout, int axis, framework::Tensor* dx,
framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
size_t N = static_cast<size_t>(framework::product(x_dim));
#if !defined(_WIN32)
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), N);
#else
platform::ForRange<DeviceContext> for_range(
ctx.device_context<DeviceContext>(), N);
#endif // !_WIN32
for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
......@@ -631,9 +643,13 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, int axis, Functor func,
framework::Tensor* z) {
#if !defined(_WIN32)
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
x, y, z, ctx.template device_context<DeviceContext>(), func);
#else
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
x, y, z, ctx.device_context<DeviceContext>(), func);
#endif // !_WIN32
auto x_dims = x->dims();
auto y_dims_untrimed = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
......
......@@ -41,13 +41,10 @@ math_library(cross_entropy)
math_library(cos_sim_functor)
math_library(depthwise_conv)
math_library(im2col)
if (NOT WIN32)
if (NOT WIN32) # Windows does not support AVX functions yet.
math_library(gru_compute DEPS activation_functions math_function)
math_library(lstm_compute DEPS activation_functions)
else()
# Windows does not support AVX functions yet.
math_library(gru_compute DEPS math_function)
math_library(lstm_compute DEPS math_function)
endif (NOT WIN32)
cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
......
......@@ -107,7 +107,11 @@ static inline void* GetDsoHandleFromDefaultPath(const std::string& dso_path,
static inline void* GetDsoHandleFromSearchPath(const std::string& search_root,
const std::string& dso_name,
bool throw_on_error = true) {
#if !defined(_WIN32)
int dynload_flags = RTLD_LAZY | RTLD_LOCAL;
#else
int dynload_flags = 0;
#endif // !_WIN32
void* dso_handle = nullptr;
std::string dlPath = dso_name;
......@@ -138,6 +142,11 @@ static inline void* GetDsoHandleFromSearchPath(const std::string& search_root,
"export LD_LIBRARY_PATH=... \n Note: After Mac OS 10.11, "
"using the DYLD_LIBRARY_PATH is impossible unless System "
"Integrity Protection (SIP) is disabled.";
#if !defined(_WIN32)
auto errorno = dlerror();
#else
auto errorno = GetLastError();
#endif // !_WIN32
if (throw_on_error) {
PADDLE_ENFORCE(nullptr != dso_handle, error_msg, dlPath, errorno);
} else if (nullptr == dso_handle) {
......
......@@ -47,7 +47,7 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/dynload/curand.h"
#if !defined(__APPLE__) and !defined(_WIN32)
#if !defined(__APPLE__) && !defined(_WIN32)
#include "paddle/fluid/platform/dynload/nccl.h"
#endif // !__APPLE__ && !_WIN32
#endif // PADDLE_WITH_CUDA
......@@ -260,12 +260,6 @@ inline void throw_on_error(T e) {
} \
} while (false)
#define PADDLE_THROW_EOF() \
do { \
throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \
__LINE__); \
} while (false)
#else
#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__);
#endif // REPLACE_ENFORCE_GLOG
......@@ -281,6 +275,12 @@ inline void throw_on_error(T e) {
#define PADDLE_ENFORCE(x, ...) x
#endif // !_WIN32
#define PADDLE_THROW_EOF() \
do { \
throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \
__LINE__); \
} while (false)
/*
* Some enforce helpers here, usage:
* int a = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册