You need to sign in or sign up before continuing.
提交 0f8c9dbe 编写于 作者: D dongzhihong

device context pointer

上级 2447c34a
......@@ -55,7 +55,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu)
op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
op_library(guassian_random_op SRCS guassain_random_op.cc guassian_random_op.cu)
op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu)
op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
......
......@@ -12,9 +12,9 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/gaussian_random_op.h"
#include "glog/logging.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/random_op.h"
namespace paddle {
namespace operators {
......@@ -22,7 +22,7 @@ namespace operators {
template <typename T>
class GaussianRandomOpKernel<platform::CPUPlace, T>
: public framework::OpKernel {
public:
public:
void Compute(const framework::KernelContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean");
auto std = context.op_.GetAttr<T>("std");
......@@ -40,7 +40,7 @@ public:
};
class GaussianRandomOp : public framework::OperatorWithKernel {
protected:
protected:
void InferShape(
const std::vector<const framework::Tensor*>& inputs,
const std::vector<framework::Tensor*>& outputs) const override {
......@@ -54,7 +54,7 @@ protected:
};
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
public:
GaussianRandomOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
......@@ -74,8 +74,7 @@ The eqution : Out = GaussianRandom(Shape=(d0, d1, ...), Dtype, mean, std)
} // namespace operators
} // namespace paddle
REGISTER_OP(gaussian_random,
paddle::operators::GaussianRandomOp,
REGISTER_OP(gaussian_random, paddle::operators::GaussianRandomOp,
paddle::operators::GaussianRandomOpMaker);
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::CPUPlace,
......
#include "paddle/operators/random_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/guassian_random_op.h"
namespace paddle {
namespace operators {
template<typename T>
class GaussianRandomOpKernel<platform::GPUPlace, T> : public framework::OpKernel {
public:
template <typename T>
class GaussianRandomOpKernel<platform::GPUPlace, T>
: public framework::OpKernel {
public:
void Compute(const framework::KernelContext& context) const override {
auto mean = context.op_.GetAttr<T>("mean");
auto std = context.op_.GetAttr<T>("std");
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
T* r = output->mutable_data<T>(context.GetPlace());
auto ctx = static_cast<const platform::GPUDeviceContext*>
(context.device_context_);
// generator need to modify context
auto ctx =
static_cast<const platform::GPUDeviceContext*>(context.device_context_);
// generator need to modify context
auto g = const_cast<platform::GPUDeviceContext*>(ctx)->RandGenerator();
curandGenerateNormal(g, r, framework::product(output->dims()), mean, std);
}
};
} // namespace operators
} // namespace paddle
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::GPUPlace, float>
RandomOpKernel_GPU_float;
typedef paddle::operators::GaussianRandomOpKernel<paddle::platform::GPUPlace,
float>
RandomOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(gaussian_random, GaussianRandomOpKernel_GPU_float);
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册