diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc index c9168fdf53f7bc3d76f8bd8eedc35b081d8397b3..922e6904ed7bb7c9db308cd20ff5bbf7e5261948 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc @@ -20,6 +20,34 @@ limitations under the License. */ namespace paddle { namespace operators { +template +struct SameDimsElemwiseAdd< + platform::CPUDeviceContext, T, + typename std::enable_if::value>::type> { + void operator()(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + framework::Tensor *z) { + auto blas = math::GetBlas(ctx); + blas.VADD(x->numel(), x->data(), y->data(), z->data()); + } +}; + +template +struct SameDimsElemwiseAdd< + platform::CPUDeviceContext, T, + typename std::enable_if::value>::type> { + void operator()(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + framework::Tensor *z) { + auto eigen_x = framework::EigenVector::Flatten(*x); + auto eigen_y = framework::EigenVector::Flatten(*y); + auto eigen_z = framework::EigenVector::Flatten(*z); + auto &place = *ctx.template device_context() + .eigen_device(); + eigen_z.device(place) = eigen_x + eigen_y; + } +}; + class ElementwiseAddOpMaker : public ElementwiseOpMaker { protected: std::string GetName() const override { return "Add"; } diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu index 15b4bff0b783f8d9a942b49b67c8a13c8e9dbf3f..de121b3786f3c9cc94a0c3dab789f372d9e17e72 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -11,13 +11,84 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; namespace plat = paddle::platform; +namespace paddle { +namespace operators { + +template +struct SameDimsElemwiseAdd { + void operator()(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z) { + AddRangeFunctor functor(x->data(), y->data(), z->data()); + auto& dev_ctx = ctx.template device_context(); + platform::ForRange for_range(dev_ctx, + x->numel()); + for_range(functor); + } +}; + +template <> +struct SameDimsElemwiseAdd { + void operator()(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z) { + auto size = x->numel(); + dim3 gird_size = dim3( + (size / 2 + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + const half* x2 = + reinterpret_cast(x->data()); + const half* y2 = + reinterpret_cast(y->data()); + half* z2 = reinterpret_cast(z->data()); + SameDimsElemwiseAddCUDAKernel<<< + gird_size, block_size, 0, + ctx.template device_context().stream()>>>( + x2, y2, z2, size); + } +}; + +template +static __global__ void SimpleElemwiseAddGradCUDAKernel(const T* dout, + int64_t size, T* dx, + T* dy) { + int col = blockIdx.x * blockDim.x + threadIdx.x; + + while (col < size) { + dx[col] = dout[col]; + dy[col] = dout[col]; + col += blockDim.x * gridDim.x; + } +} + +template +typename std::enable_if< + std::is_same::value>::type +elementwise_add_grad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy) { + dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + auto size = x->numel(); + dim3 gird_size = + dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + SimpleElemwiseAddGradCUDAKernel< + T><<().stream()>>>( + dout->data(), size, dx->mutable_data(ctx.GetPlace()), + dy->mutable_data(ctx.GetPlace())); +} + +} // namespace operators +} // namespace paddle REGISTER_OP_CUDA_KERNEL( elementwise_add, ops::ElementwiseAddKernel, ops::ElementwiseAddKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h index 7f8b0ffe92fd40d7944f05282c4edc8271547e00..315a22903147dcf86fa19a791cfb4b894d4c72f9 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -11,22 +11,15 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - #pragma once -#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" - namespace paddle { namespace operators { -template -struct AddFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { return a + b; } -}; - template void default_elementwise_add(const framework::ExecutionContext &ctx, const framework::Tensor *x, @@ -36,31 +29,12 @@ void default_elementwise_add(const framework::ExecutionContext &ctx, AddFunctor(), z); } -template -typename std::enable_if< - std::is_floating_point::value && - std::is_same::value>::type -elementwise_add_same_dims(const framework::ExecutionContext &ctx, - const framework::Tensor *x, - const framework::Tensor *y, framework::Tensor *z) { - auto blas = math::GetBlas(ctx); - blas.VADD(x->numel(), x->data(), y->data(), z->data()); -} - -template -typename std::enable_if< - !std::is_floating_point::value || - !std::is_same::value>::type -elementwise_add_same_dims(const framework::ExecutionContext &ctx, - const framework::Tensor *x, - const framework::Tensor *y, framework::Tensor *z) { - auto eigen_x = framework::EigenVector::Flatten(*x); - auto eigen_y = framework::EigenVector::Flatten(*y); - auto eigen_z = framework::EigenVector::Flatten(*z); - - auto &place = *ctx.template device_context().eigen_device(); - eigen_z.device(place) = eigen_x + eigen_y; -} +template +struct SameDimsElemwiseAdd { + void operator()(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + framework::Tensor *z); +}; template class ElementwiseAddKernel : public framework::OpKernel { @@ -69,12 +43,11 @@ class ElementwiseAddKernel : public framework::OpKernel { auto *x = ctx.Input("X"); auto *y = ctx.Input("Y"); auto *z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - auto dims_equal = x->dims() == y->dims(); if (dims_equal) { - elementwise_add_same_dims(ctx, x, y, z); + SameDimsElemwiseAdd same_dims_add; + same_dims_add(ctx, x, y, z); } else { default_elementwise_add(ctx, x, y, z); } @@ -112,7 +85,6 @@ elementwise_add_grad(const framework::ExecutionContext &ctx, const framework::Tensor *dout, framework::Tensor *dx, framework::Tensor *dy) { auto blas = math::GetBlas(ctx); - if (dx) { blas.VCOPY(dout->numel(), dout->data(), dx->mutable_data(ctx.GetPlace())); @@ -126,8 +98,8 @@ elementwise_add_grad(const framework::ExecutionContext &ctx, template typename std::enable_if< - !std::is_floating_point::value || - !std::is_same::value>::type + !std::is_floating_point::value && + std::is_same::value>::type elementwise_add_grad(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, const framework::Tensor *out, @@ -136,6 +108,18 @@ elementwise_add_grad(const framework::ExecutionContext &ctx, default_elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } +#ifdef PADDLE_WITH_CUDA +// cuda definition +template +typename std::enable_if< + std::is_same::value>::type +elementwise_add_grad(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + const framework::Tensor *out, + const framework::Tensor *dout, framework::Tensor *dx, + framework::Tensor *dy); +#endif + template class ElementwiseAddGradKernel : public ElemwiseGradKernel { public: @@ -151,8 +135,7 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel { auto *out = dout; auto *x = dout, *y = dout; - if (platform::is_cpu_place(ctx.GetPlace()) && dx != nullptr && - dy != nullptr && (dx->dims() == dy->dims())) { + if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { elementwise_add_grad(ctx, x, y, out, dout, dx, dy); } else { default_elementwise_add_grad(ctx, x, y, out, dout, dx, diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc index 2002e8f31cb6000612cbd30fb82a3da5762daa3c..000055a4b17225a4df69623d8b7d60c5d34ee31c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc @@ -20,6 +20,34 @@ limitations under the License. */ namespace paddle { namespace operators { +template +struct SameDimsElemwiseDiv< + platform::CPUDeviceContext, T, + typename std::enable_if::value>::type> { + void operator()(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + framework::Tensor *z) { + auto blas = math::GetBlas(ctx); + blas.VDIV(x->numel(), x->data(), y->data(), z->data()); + } +}; + +template +struct SameDimsElemwiseDiv< + platform::CPUDeviceContext, T, + typename std::enable_if::value>::type> { + void operator()(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + framework::Tensor *z) { + auto eigen_x = framework::EigenVector::Flatten(*x); + auto eigen_y = framework::EigenVector::Flatten(*y); + auto eigen_z = framework::EigenVector::Flatten(*z); + auto &place = *ctx.template device_context() + .eigen_device(); + eigen_z.device(place) = eigen_x / eigen_y; + } +}; + class ElementwiseDivOpMaker : public ElementwiseOpMaker { protected: std::string GetName() const override { return "Div"; } diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu index 4cd17b94e5dcd10e390a769f4cf77b3b772a7a86..b1698491180d773dc35d09d8d7ec642d3ed914fe 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -12,9 +12,87 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_div_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; + +namespace paddle { +namespace operators { + +template +struct SameDimsElemwiseDiv { + void operator()(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z) { + DivRangeFunctor functor(x->data(), y->data(), z->data()); + auto& dev_ctx = ctx.template device_context(); + platform::ForRange for_range(dev_ctx, + x->numel()); + for_range(functor); + } +}; + +template <> +struct SameDimsElemwiseDiv { + void operator()(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z) { + auto size = x->numel(); + dim3 gird_size = dim3( + (size / 2 + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + const half* x2 = + reinterpret_cast(x->data()); + const half* y2 = + reinterpret_cast(y->data()); + half* z2 = reinterpret_cast(z->data()); + SameDimsElemwiseDivCUDAKernel<<< + gird_size, block_size, 0, + ctx.template device_context().stream()>>>( + x2, y2, z2, size); + } +}; + +template +static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y, + const T* out, + const T* dout, + int64_t size, T* dx, + T* dy) { + int col = blockIdx.x * blockDim.x + threadIdx.x; + + while (col < size) { + T o = dout[col]; + dx[col] = o / y[col]; + dy[col] = -o * out[col] / y[col]; + col += blockDim.x * gridDim.x; + } +} + +template +typename std::enable_if< + std::is_same::value>::type +elementwise_div_grad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy) { + dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + auto size = x->numel(); + dim3 gird_size = + dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + SimpleElemwiseDivGradCUDAKernel< + T><<().stream()>>>( + x->data(), y->data(), out->data(), dout->data(), size, + dx->mutable_data(ctx.GetPlace()), dy->mutable_data(ctx.GetPlace())); +} + +} // namespace operators +} // namespace paddle REGISTER_OP_CUDA_KERNEL( elementwise_div, diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index a1c5684ea800a93ed7a56fa5d99b947691cd4488..3c460242f3d871cf3415ede203267a7928494678 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -17,16 +17,29 @@ limitations under the License. */ #include #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" +#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h" namespace paddle { namespace operators { -template -struct DivFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { return a / b; } +template +void default_elementwise_div(const framework::ExecutionContext& ctx, + const framework::Tensor* x, + const framework::Tensor* y, framework::Tensor* z) { + int axis = ctx.Attr("axis"); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + DivFunctor(), z); +} + +template +struct SameDimsElemwiseDiv { + void operator()(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z); }; template @@ -36,11 +49,15 @@ class ElementwiseDivKernel : public framework::OpKernel { auto* x = ctx.Input("X"); auto* y = ctx.Input("Y"); auto* z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - DivFunctor(), z); + + auto dims_equal = x->dims() == y->dims(); + if (dims_equal) { + SameDimsElemwiseDiv same_dims_div; + same_dims_div(ctx, x, y, z); + } else { + default_elementwise_div(ctx, x, y, z); + } } }; @@ -63,6 +80,31 @@ struct DivDoubleDY { } }; +template +typename std::enable_if< + std::is_same::value>::type +elementwise_div_grad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy) { + int axis = ctx.Attr("axis"); + ElemwiseGradCompute, DivGradDY>( + ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX(), DivGradDY()); +} + +#ifdef PADDLE_WITH_CUDA +// cuda definition +template +typename std::enable_if< + std::is_same::value>::type +elementwise_div_grad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy); +#endif + template class ElementwiseDivGradKernel : public ElemwiseGradKernel { public: @@ -76,11 +118,15 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel { auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); int axis = ctx.Attr("axis"); - auto* x = dout; // Fake x, not used - ElemwiseGradCompute, DivGradDY>( - ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX(), DivGradDY()); + if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { + elementwise_div_grad(ctx, x, y, out, dout, dx, dy); + } else { + ElemwiseGradCompute, DivGradDY>( + ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX(), + DivGradDY()); + } } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index 0998b27ea11bec086cdfe580519b3026a5834074..d843fb4fd11444949eeb2e331c1cc41335599935 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -20,6 +20,34 @@ limitations under the License. */ namespace paddle { namespace operators { +template +struct SameDimsElemwiseMul< + platform::CPUDeviceContext, T, + typename std::enable_if::value>::type> { + void operator()(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + framework::Tensor *z) { + auto blas = math::GetBlas(ctx); + blas.VMUL(x->numel(), x->data(), y->data(), z->data()); + } +}; + +template +struct SameDimsElemwiseMul< + platform::CPUDeviceContext, T, + typename std::enable_if::value>::type> { + void operator()(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + framework::Tensor *z) { + auto eigen_x = framework::EigenVector::Flatten(*x); + auto eigen_y = framework::EigenVector::Flatten(*y); + auto eigen_z = framework::EigenVector::Flatten(*z); + auto &place = *ctx.template device_context() + .eigen_device(); + eigen_z.device(place) = eigen_x * eigen_y; + } +}; + class ElementwiseMulOpMaker : public ElementwiseOpMaker { protected: std::string GetName() const override { return "Mul"; } diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index d3c0dcb40958c21d96e266425b501dcd763b8f3a..4814cb144f057d4cb76b416e896e24ea227e92a2 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -13,15 +13,49 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/platform/float16.h" -#define TILE_SIZE 512 namespace ops = paddle::operators; namespace plat = paddle::platform; namespace paddle { namespace operators { +template +struct SameDimsElemwiseMul { + void operator()(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z) { + MulRangeFunctor functor(x->data(), y->data(), z->data()); + auto& dev_ctx = ctx.template device_context(); + platform::ForRange for_range(dev_ctx, + x->numel()); + for_range(functor); + } +}; + +template <> +struct SameDimsElemwiseMul { + void operator()(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z) { + auto size = x->numel(); + dim3 gird_size = dim3( + (size / 2 + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + const half* x2 = + reinterpret_cast(x->data()); + const half* y2 = + reinterpret_cast(y->data()); + half* z2 = reinterpret_cast(z->data()); + SameDimsElemwiseMulCUDAKernel<<< + gird_size, block_size, 0, + ctx.template device_context().stream()>>>( + x2, y2, z2, size); + } +}; + template static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y, const T* out, @@ -38,40 +72,24 @@ static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y, } } -template -class ElementwiseMulGradKernel - : public ElemwiseGradKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - ElemwiseGradKernel::Compute(ctx); - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* out = dout; // out is not necessary - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - int axis = ctx.Attr("axis"); - - if (x->dims() == y->dims() && dx && dy) { - dim3 block_size = dim3(TILE_SIZE, 1); - auto size = x->numel(); - dim3 gird_size = dim3((size + TILE_SIZE - 1) / TILE_SIZE, 1); - SimpleElemwiseMulGradCUDAKernel<<< - gird_size, block_size, 0, - ctx.template device_context().stream()>>>( - x->data(), y->data(), out->data(), dout->data(), size, - dx->mutable_data(ctx.GetPlace()), - dy->mutable_data(ctx.GetPlace())); - return; - } else { - ElemwiseGradCompute, - MulGradDY>(ctx, *x, *y, *out, *dout, axis, dx, dy, - MulGradDX(), MulGradDY()); - } - } -}; +template +typename std::enable_if< + std::is_same::value>::type +elementwise_mul_grad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy) { + dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + auto size = x->numel(); + dim3 gird_size = + dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + SimpleElemwiseMulGradCUDAKernel< + T><<().stream()>>>( + x->data(), y->data(), out->data(), dout->data(), size, + dx->mutable_data(ctx.GetPlace()), dy->mutable_data(ctx.GetPlace())); +} } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index 581caad62ed5d382af8957631ff8dbdbc401b1cb..49f0e305b60cc51ea21e283af4aea1a8b44470d5 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -14,17 +14,13 @@ limitations under the License. */ #pragma once #include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { -template -struct MulFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { return a * b; } -}; - template void default_elementwise_mul(const framework::ExecutionContext& ctx, const framework::Tensor* x, @@ -33,32 +29,12 @@ void default_elementwise_mul(const framework::ExecutionContext& ctx, ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, MulFunctor(), z); } - -template -typename std::enable_if< - std::is_floating_point::value && - std::is_same::value>::type -elementwise_mul_same_dims(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, framework::Tensor* z) { - auto blas = math::GetBlas(ctx); - blas.VMUL(x->numel(), x->data(), y->data(), z->data()); -} - -template -typename std::enable_if< - !std::is_floating_point::value || - !std::is_same::value>::type -elementwise_mul_same_dims(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, framework::Tensor* z) { - auto eigen_x = framework::EigenVector::Flatten(*x); - auto eigen_y = framework::EigenVector::Flatten(*y); - auto eigen_z = framework::EigenVector::Flatten(*z); - - auto& place = *ctx.template device_context().eigen_device(); - eigen_z.device(place) = eigen_x * eigen_y; -} +template +struct SameDimsElemwiseMul { + void operator()(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z); +}; template class ElementwiseMulKernel : public framework::OpKernel { @@ -92,7 +68,8 @@ class ElementwiseMulKernel : public framework::OpKernel { z->mutable_data(ctx.GetPlace()); if (x.numel() == y->numel()) { - elementwise_mul_same_dims(ctx, &x, y, z); + SameDimsElemwiseMul same_dims_mul; + same_dims_mul(ctx, &x, y, z); } else { default_elementwise_mul(ctx, &x, y, z); } @@ -109,6 +86,31 @@ struct MulGradDY { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; } }; +template +typename std::enable_if< + std::is_same::value>::type +elementwise_mul_grad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy) { + int axis = ctx.Attr("axis"); + ElemwiseGradCompute, MulGradDY>( + ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX(), MulGradDY()); +} + +#ifdef PADDLE_WITH_CUDA +// cuda definition +template +typename std::enable_if< + std::is_same::value>::type +elementwise_mul_grad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy); +#endif + template class ElementwiseMulGradKernel : public ElemwiseGradKernel { public: @@ -123,8 +125,13 @@ class ElementwiseMulGradKernel : public ElemwiseGradKernel { auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); int axis = ctx.Attr("axis"); - ElemwiseGradCompute, MulGradDY>( - ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX(), MulGradDY()); + if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { + elementwise_mul_grad(ctx, x, y, out, dout, dx, dy); + } else { + ElemwiseGradCompute, MulGradDY>( + ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX(), + MulGradDY()); + } } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h new file mode 100644 index 0000000000000000000000000000000000000000..263f62255481901f0b0df7210d2ea8e3adbaaae3 --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h @@ -0,0 +1,159 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/hostdevice.h" + +#define PADDLE_CUDA_THREAD_SIZE 512 + +#ifdef PADDLE_WITH_CUDA +#include +#endif // PADDLE_WITH_CUDA + +#ifdef PADDLE_CUDA_FP16 +#include +#endif + +#if CUDA_VERSION < 9000 +#define __h2div h2div +#endif + +namespace paddle { +namespace operators { + +#define DEFINE_SIMPLE_BINARY_FUNCTOR(Func, expr) \ + template \ + struct Func##Functor { \ + inline HOSTDEVICE T operator()(const T& a, const T& b) const { \ + return a expr b; \ + } \ + }; + +DEFINE_SIMPLE_BINARY_FUNCTOR(Add, +) +DEFINE_SIMPLE_BINARY_FUNCTOR(Sub, -) +DEFINE_SIMPLE_BINARY_FUNCTOR(Mul, *) +DEFINE_SIMPLE_BINARY_FUNCTOR(Div, /) +#undef DEFINE_SIMPLE_BINARY_FUNCTOR + +#define DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Func, expr) \ + template \ + struct Func##RangeFunctor { \ + Func##RangeFunctor(const T* x, const T* y, T* z) : x_(x), y_(y), z_(z) {} \ + inline HOSTDEVICE void operator()(size_t id) const { \ + z_[id] = x_[id] expr y_[id]; \ + } \ + const T* x_; \ + const T* y_; \ + T* z_; \ + }; +DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Add, +) +DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Sub, -) +DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Mul, *) +DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Div, /) +#undef DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR + +#ifdef PADDLE_CUDA_FP16 +inline DEVICE half2 half2_add(const half2& a, const half2& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hadd2(a, b); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 + b1; + float r2 = a2 + b2; + return __floats2half2_rn(r1, r2); +#endif +} + +inline DEVICE half2 half2_sub(const half2& a, const half2& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hsub2(a, b); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 - b1; + float r2 = a2 - b2; + return __floats2half2_rn(r1, r2); +#endif +} + +inline DEVICE half2 half2_mul(const half2& a, const half2& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __hmul2(a, b); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 * b1; + float r2 = a2 * b2; + return __floats2half2_rn(r1, r2); +#endif +} + +inline DEVICE half2 half2_div(const half2& a, const half2& b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 + return __h2div(a, b); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 / b1; + float r2 = a2 / b2; + return __floats2half2_rn(r1, r2); +#endif +} + +#define DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Func, expr, FP16Function) \ + template \ + __global__ void SameDimsElemwise##Func##CUDAKernel(const T* x, const T* y, \ + T* z, int64_t size) { \ + int col = blockIdx.x * blockDim.x + threadIdx.x; \ + while (col < size) { \ + z[col] = x[col] expr y[col]; \ + col += blockDim.x * gridDim.x; \ + } \ + } \ + template <> \ + inline __global__ void SameDimsElemwise##Func##CUDAKernel( \ + const half* x, const half* y, half* z, int64_t size) { \ + int start = threadIdx.x + blockDim.x * blockIdx.x; \ + int stride = blockDim.x * gridDim.x; \ + int n2 = size / 2; \ + const half2* x2 = reinterpret_cast(x); \ + const half2* y2 = reinterpret_cast(y); \ + half2* z2 = reinterpret_cast(z); \ + for (int i = start; i < n2; i += stride) { \ + z2[i] = FP16Function(x2[i], y2[i]); \ + } \ + if (start == 0 && (size % 2)) { \ + z[size - 1] = __float2half(__half2float(x[size - 1]) \ + expr __half2float(y[size - 1])); \ + } \ + } +DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Add, +, half2_add) +DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Sub, -, half2_sub) +DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Mul, *, half2_mul) +DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Div, /, half2_div) +#undef DEFINE_SIMPLE_CUDA_BINARY_KERNEL + +#endif // PADDLE_CUDA_FP16 + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc index 1692a8c2f235cb0d28dd0f53986aa69b03f0e880..48979348218c7d04dc1f14b79ba39e6bc15a5b59 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc @@ -20,6 +20,33 @@ limitations under the License. */ namespace paddle { namespace operators { +template +struct SameDimsElemwiseSub< + platform::CPUDeviceContext, T, + typename std::enable_if::value>::type> { + void operator()(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + framework::Tensor *z) { + auto blas = math::GetBlas(ctx); + blas.VSUB(x->numel(), x->data(), y->data(), z->data()); + } +}; + +template +struct SameDimsElemwiseSub< + platform::CPUDeviceContext, T, + typename std::enable_if::value>::type> { + void operator()(const framework::ExecutionContext &ctx, + const framework::Tensor *x, const framework::Tensor *y, + framework::Tensor *z) { + auto eigen_x = framework::EigenVector::Flatten(*x); + auto eigen_y = framework::EigenVector::Flatten(*y); + auto eigen_z = framework::EigenVector::Flatten(*z); + auto &place = *ctx.template device_context() + .eigen_device(); + eigen_z.device(place) = eigen_x - eigen_y; + } +}; class ElementwiseSubOpMaker : public ElementwiseOpMaker { protected: std::string GetName() const override { return "Sub"; } diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu index 52fad7fd04b0083c81089899d4dab80853441ca7..7ff72028091ed78e7ca5c27d2b8bb362c12fd152 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu @@ -11,10 +11,85 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; +namespace plat = paddle::platform; + +namespace paddle { +namespace operators { + +template +struct SameDimsElemwiseSub { + void operator()(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z) { + SubRangeFunctor functor(x->data(), y->data(), z->data()); + auto& dev_ctx = ctx.template device_context(); + platform::ForRange for_range(dev_ctx, + x->numel()); + for_range(functor); + } +}; + +template <> +struct SameDimsElemwiseSub { + void operator()(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z) { + auto size = x->numel(); + dim3 gird_size = dim3( + (size / 2 + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + const half* x2 = + reinterpret_cast(x->data()); + const half* y2 = + reinterpret_cast(y->data()); + half* z2 = reinterpret_cast(z->data()); + SameDimsElemwiseSubCUDAKernel<<< + gird_size, block_size, 0, + ctx.template device_context().stream()>>>( + x2, y2, z2, size); + } +}; + +template +static __global__ void SimpleElemwiseSubGradCUDAKernel(const T* dout, + int64_t size, T* dx, + T* dy) { + int col = blockIdx.x * blockDim.x + threadIdx.x; + + while (col < size) { + dx[col] = dout[col]; + dy[col] = -dout[col]; + col += blockDim.x * gridDim.x; + } +} + +template +typename std::enable_if< + std::is_same::value>::type +elementwise_sub_grad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy) { + dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + auto size = x->numel(); + dim3 gird_size = + dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + SimpleElemwiseSubGradCUDAKernel< + T><<().stream()>>>( + dout->data(), size, dx->mutable_data(ctx.GetPlace()), + dy->mutable_data(ctx.GetPlace())); +} + +} // namespace operators +} // namespace paddle REGISTER_OP_CUDA_KERNEL( elementwise_sub, diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h index 5049d587b582a71981f45a72dc5bfc133dadb52d..1a64a4535985e2d4ba159ec7de6a0ffe3e657369 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -14,14 +14,27 @@ limitations under the License. */ #pragma once #include "paddle/fluid/operators/elementwise/elementwise_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { -template -struct SubFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { return a - b; } +template +void default_elementwise_sub(const framework::ExecutionContext& ctx, + const framework::Tensor* x, + const framework::Tensor* y, framework::Tensor* z) { + int axis = ctx.Attr("axis"); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + SubFunctor(), z); +} + +template +struct SameDimsElemwiseSub { + void operator()(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z); }; template @@ -31,11 +44,15 @@ class ElementwiseSubKernel : public framework::OpKernel { auto* x = ctx.Input("X"); auto* y = ctx.Input("Y"); auto* z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - SubFunctor(), z); + + auto dims_equal = x->dims() == y->dims(); + if (dims_equal) { + SameDimsElemwiseSub same_dims_sub; + same_dims_sub(ctx, x, y, z); + } else { + default_elementwise_sub(ctx, x, y, z); + } } }; @@ -49,6 +66,31 @@ struct SubGradDY { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; } }; +template +typename std::enable_if< + std::is_same::value>::type +elementwise_sub_grad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy) { + int axis = ctx.Attr("axis"); + ElemwiseExplicitGradCompute, SubGradDY>( + ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX(), SubGradDY()); +} + +#ifdef PADDLE_WITH_CUDA +// cuda definition +template +typename std::enable_if< + std::is_same::value>::type +elementwise_sub_grad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, + const framework::Tensor* dout, framework::Tensor* dx, + framework::Tensor* dy); +#endif + template class ElementwiseSubGradKernel : public ElemwiseGradKernel { public: @@ -63,9 +105,13 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel { // skip out, x, y auto* out = dout; auto *x = dout, *y = dout; - - ElemwiseExplicitGradCompute, SubGradDY>( - ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX(), SubGradDY()); + if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { + elementwise_sub_grad(ctx, x, y, out, dout, dx, dy); + } else { + ElemwiseExplicitGradCompute, SubGradDY>( + ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX(), + SubGradDY()); + } } }; diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index a15dab935552c6e93fd9c0d9963985d6ea024f35..b0148a705542378ec670292c1d305d8d434c35c4 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -159,9 +159,15 @@ class Blas { template void VADD(int n, const T* x, const T* y, T* z) const; + template + void VSUB(int n, const T* x, const T* y, T* z) const; + template void VMUL(int n, const T* x, const T* y, T* z) const; + template + void VDIV(int n, const T* x, const T* y, T* z) const; + template void VCOPY(int n, const T* x, T* y) const; @@ -275,11 +281,21 @@ class BlasT : private Blas { Base()->template VADD(args...); } + template + void VSUB(ARGS... args) const { + Base()->template VSUB(args...); + } + template void VMUL(ARGS... args) const { Base()->template VMUL(args...); } + template + void VDIV(ARGS... args) const { + Base()->template VDIV(args...); + } + template void VCOPY(ARGS... args) const { Base()->template VCOPY(args...); diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index e2620bcfd9298f38f887f8a5b35aa8efba6b7053..817429be4429ca1c87443ba9688fc45ef5a8ab79 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -99,11 +99,21 @@ struct CBlas { platform::dynload::vsAdd(args...); } + template + static void VSUB(ARGS... args) { + platform::dynload::vsSub(args...); + } + template static void VMUL(ARGS... args) { platform::dynload::vsMul(args...); } + template + static void VDIV(ARGS... args) { + platform::dynload::vsDiv(args...); + } + template static void VEXP(ARGS... args) { platform::dynload::vsExp(args...); @@ -210,11 +220,21 @@ struct CBlas { platform::dynload::vdAdd(args...); } + template + static void VSUB(ARGS... args) { + platform::dynload::vdSub(args...); + } + template static void VMUL(ARGS... args) { platform::dynload::vdMul(args...); } + template + static void VDIV(ARGS... args) { + platform::dynload::vdDiv(args...); + } + template static void VEXP(ARGS... args) { platform::dynload::vdExp(args...); @@ -443,6 +463,20 @@ void Blas::VADD(int n, const T *x, const T *y, #endif } +template <> +template +void Blas::VSUB(int n, const T *x, const T *y, + T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VSUB(n, x, y, z); +#else + // try to find if openblas support vsub + for (int i = 0; i < n; ++i) { + z[i] = x[i] - y[i]; + } +#endif +} + template <> template void Blas::VMUL(int n, const T *x, const T *y, @@ -457,6 +491,20 @@ void Blas::VMUL(int n, const T *x, const T *y, #endif } +template <> +template +void Blas::VDIV(int n, const T *x, const T *y, + T *z) const { +#ifdef PADDLE_WITH_MKLML + CBlas::VDIV(n, x, y, z); +#else + // try to find if openblas support vdiv + for (int i = 0; i < n; ++i) { + z[i] = x[i] / y[i]; + } +#endif +} + template <> template void Blas::VEXP(int n, const T *x, T *y) const { diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 5070be43756fa0a0a08a410fcfcdbadaf751c424..839dcd87f57e1f9b8d8af751f2ef274f0f54b2bb 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -76,8 +73,12 @@ extern void* mklml_dso_handle; __macro(cblas_dscal); \ __macro(vsAdd); \ __macro(vdAdd); \ + __macro(vsSub); \ + __macro(vdSub); \ __macro(vsMul); \ __macro(vdMul); \ + __macro(vsDiv); \ + __macro(vdDiv); \ __macro(vsExp); \ __macro(vdExp); \ __macro(vsSqr); \