提交 0f8c9dbe 编写于 作者: D dongzhihong

device context pointer

上级 2447c34a
...@@ -55,7 +55,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) ...@@ -55,7 +55,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu)
op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
op_library(guassian_random_op SRCS guassain_random_op.cc guassian_random_op.cu) op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu)
op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu) op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/gaussian_random_op.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/random_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -22,7 +22,7 @@ namespace operators { ...@@ -22,7 +22,7 @@ namespace operators {
template <typename T> template <typename T>
class GaussianRandomOpKernel<platform::CPUPlace, T> class GaussianRandomOpKernel<platform::CPUPlace, T>
: public framework::OpKernel { : public framework::OpKernel {
public: public:
void Compute(const framework::KernelContext& context) const override { void Compute(const framework::KernelContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean"); auto mean = context.op_.GetAttr<T>("mean");
auto std = context.op_.GetAttr<T>("std"); auto std = context.op_.GetAttr<T>("std");
...@@ -40,7 +40,7 @@ public: ...@@ -40,7 +40,7 @@ public:
}; };
class GaussianRandomOp : public framework::OperatorWithKernel { class GaussianRandomOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(
const std::vector<const framework::Tensor*>& inputs, const std::vector<const framework::Tensor*>& inputs,
const std::vector<framework::Tensor*>& outputs) const override { const std::vector<framework::Tensor*>& outputs) const override {
...@@ -54,7 +54,7 @@ protected: ...@@ -54,7 +54,7 @@ protected:
}; };
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
GaussianRandomOpMaker(framework::OpProto* proto, GaussianRandomOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : framework::OpProtoAndCheckerMaker(proto, op_checker) {
...@@ -74,8 +74,7 @@ The eqution : Out = GaussianRandom(Shape=(d0, d1, ...), Dtype, mean, std) ...@@ -74,8 +74,7 @@ The eqution : Out = GaussianRandom(Shape=(d0, d1, ...), Dtype, mean, std)
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(gaussian_random, REGISTER_OP(gaussian_random, paddle::operators::GaussianRandomOp,
paddle::operators::GaussianRandomOp,
paddle::operators::GaussianRandomOpMaker); paddle::operators::GaussianRandomOpMaker);
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::CPUPlace, typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::CPUPlace,
......
#include "paddle/operators/random_op.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/guassian_random_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template<typename T> template <typename T>
class GaussianRandomOpKernel<platform::GPUPlace, T> : public framework::OpKernel { class GaussianRandomOpKernel<platform::GPUPlace, T>
public: : public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override { void Compute(const framework::KernelContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean"); auto mean = context.op_.GetAttr<T>("mean");
auto std = context.op_.GetAttr<T>("std"); auto std = context.op_.GetAttr<T>("std");
auto* output = context.Output(0)->GetMutable<framework::Tensor>(); auto* output = context.Output(0)->GetMutable<framework::Tensor>();
T* r = output->mutable_data<T>(context.GetPlace()); T* r = output->mutable_data<T>(context.GetPlace());
auto ctx = static_cast<const platform::GPUDeviceContext*> auto ctx =
(context.device_context_); static_cast<const platform::GPUDeviceContext*>(context.device_context_);
// generator need to modify context // generator need to modify context
auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator(); auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();
curandGenerateNormal(g, r, framework::product(output->dims()), mean, std); curandGenerateNormal(g, r, framework::product(output->dims()), mean, std);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::GPUPlace, float> typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::GPUPlace,
RandomOpKernel_GPU_float; float>
RandomOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(gaussian_random, GaussianRandomOpKernel_GPU_float); REGISTER_OP_GPU_KERNEL(gaussian_random, GaussianRandomOpKernel_GPU_float);
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册