Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/infiniop/ops/zeros/bang/zeros_bang.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __ZEROS_BANG_API_H__
#define __ZEROS_BANG_API_H__

#include "../../../elementwise/bang/elementwise_bang.h"

// Declares the descriptor for the `zeros` operator on the `bang` backend through
// the shared elementwise macro from elementwise_bang.h.
// NOTE(review): presumably this expands to the op::zeros::bang::Descriptor class
// whose create()/calculate() are defined in zeros_bang.mlu - confirm against the
// macro definition.
ELEMENTWISE_DESCRIPTOR(zeros, bang)

#endif // __ZEROS_BANG_API_H__
176 changes: 79 additions & 97 deletions src/infiniop/ops/zeros/bang/zeros_bang.mlu
Original file line number Diff line number Diff line change
@@ -1,103 +1,52 @@
#include "../../../devices/bang/common_bang.h"
#include "zeros_bang.h"

#include <algorithm>

// Per-core NRAM staging area used by zerosKernel as the source of zero bytes.
__nram__ char zeros_nram_buffer[NRAM_MAX_SIZE];

/**
 * Fills `total_bytes` bytes starting at `output` with zero.
 *
 * The range is cut into chunks of at most `NRAM_MAX_SIZE - ALIGN_SIZE` bytes.
 * Task `taskId` handles chunks taskId, taskId + taskDim, taskId + 2 * taskDim, ...
 * Each chunk is produced by copying from a zero-filled NRAM buffer to GDRAM.
 *
 * @param output       destination buffer in GDRAM (byte-addressed)
 * @param total_bytes  number of bytes to clear; 0 is a no-op
 */
__mlu_global__ void zerosKernel(uint8_t *__restrict__ output, size_t total_bytes) {
    // ALIGN_SIZE bytes are reserved so the staging pointer can be rounded up.
    const size_t max_chunk = NRAM_MAX_SIZE - ALIGN_SIZE;
    const size_t first_start = taskId * max_chunk;

    // Nothing assigned to this task (this also covers total_bytes == 0).
    if (first_start >= total_bytes) {
        return;
    }

    // Round the buffer address up to the next ALIGN_SIZE boundary.
    // NOTE(review): the mask trick assumes ALIGN_SIZE is a power of two - confirm.
    uint8_t *cache = reinterpret_cast<uint8_t *>((reinterpret_cast<size_t>(zeros_nram_buffer) + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1));

    // The staging buffer is only ever read after this point, and a task's first
    // chunk is never smaller than any of its later chunks. Zeroing once, sized for
    // the first chunk, is therefore enough; re-zeroing per iteration was redundant.
    const size_t first_chunk = std::min(max_chunk, total_bytes - first_start);
    __bang_write_value(cache, first_chunk, static_cast<uint8_t>(0));

    const size_t stride = taskDim * max_chunk;
    for (size_t start = first_start; start < total_bytes; start += stride) {
        const size_t current = std::min(max_chunk, total_bytes - start);
        __memcpy(output + start, cache, current, NRAM2GDRAM);
    }
}

/**
 * Launches zerosKernel over `total_bytes` bytes of `output` and waits for it.
 *
 * The launch grid is (core_per_cluster, cluster_count, 1). Buffers larger than
 * 1 MiB are launched as cnrtFuncTypeUnion1, everything else as cnrtFuncTypeBlock.
 * The queue is synchronized before returning, so the call is blocking.
 *
 * @return INFINI_STATUS_SUCCESS once the queue sync has passed CNRT_CHECK.
 */
static infiniStatus_t launchZeros(
    int core_per_cluster,
    int cluster_count,
    cnrtQueue_t queue,
    void *output,
    size_t total_bytes) {
    // Pick the function type first: small buffers stay on the Block type.
    cnrtFunctionType_t launch_type = cnrtFuncTypeBlock;
    if (total_bytes > 1024 * 1024) {
        launch_type = cnrtFuncTypeUnion1;
    }

    cnrtDim3_t grid;
    grid.z = 1;
    grid.y = cluster_count;
    grid.x = core_per_cluster;

    uint8_t *dst = reinterpret_cast<uint8_t *>(output);
    zerosKernel<<<grid, launch_type, queue>>>(dst, total_bytes);

    CNRT_CHECK(cnrtQueueSync(queue));
    return INFINI_STATUS_SUCCESS;
}
// NOTE(review): presumably declares the launchZerosKernel<T> entry point that
// ZerosOp::launch calls below; the matching LAUNCH_ELEMENTWISE_KERNEL_IMPL and
// per-type instantiations live in zeros_bang_internal.mlu - confirm against the
// macro definition.
LAUNCH_ELEMENTWISE_KERNEL(Zeros)

namespace op::zeros::bang {

// Backend-private state of the descriptor: a shared reference to the BANG
// handle internals, which calculate() queries for the core-per-cluster and
// cluster counts.
struct Descriptor::Opaque {
std::shared_ptr<device::bang::Handle::Internal> internal;
};

// Releases the Opaque state that create() allocates with `new`.
// NOTE(review): a second definition of this destructor (`= default`) appears
// further down in this diff; the two belong to different sides of the change
// and cannot both be present in one translation unit.
Descriptor::~Descriptor() {
delete _opaque;
}

// Returns the number of bytes spanned by the storage behind `desc`:
// the byte offset of the last addressable element, i.e. the sum over all
// dimensions of (shape[i] - 1) * byte_stride[i], plus the size of one element.
// An empty tensor (numel() == 0) spans 0 bytes.
// The stride is cast to size_t, so a negative stride would wrap around;
// create() rejects negative strides before calling this helper.
static size_t storageSpanBytes(infiniopTensorDescriptor_t desc) {
if (desc->numel() == 0) {
return 0;
// NOTE(review): the ZerosOp definition below is interleaved into this function
// by the diff rendering; it is not part of storageSpanBytes.
//
// Host-side operator tag for the elementwise framework (used as
// calculate<ZerosOp, T>). launch() forwards all arguments to
// launchZerosKernel<Tdata> and then reports INFINI_STATUS_SUCCESS
// unconditionally - whatever launchZerosKernel returns is not inspected.
typedef struct ZerosOp {
static constexpr size_t num_inputs = 1;
template <typename Tdata, typename... Args>
static infiniStatus_t launch(Args... args) {
launchZerosKernel<Tdata>(args...);
return INFINI_STATUS_SUCCESS;
}
} ZerosOp;

auto shape = desc->shape();
auto byte_strides = desc->getByteStrides();
size_t max_offset = 0;
// Accumulate the offset of the farthest element along every dimension.
for (size_t i = 0; i < shape.size(); ++i) {
max_offset += (shape[i] - 1) * static_cast<size_t>(byte_strides[i]);
}
// Add one element so the span covers the last element itself.
return max_offset + infiniSizeOf(desc->dtype());
}
Descriptor::~Descriptor() = default;

// Builds a zeros descriptor for the BANG backend and stores it in *desc_ptr.
//
// NOTE(review): this span interleaves two variants of the function (two
// parameter lists, two `dtype` declarations, two validation sequences) because
// the diff markers were lost. The comments below describe each fragment as written.
//
// Both variants accept the same element types (BYTE, BOOL, I8..I64, U8..U64,
// F16, F32, F64, BF16) and require the output shape to equal the shape of the
// first input descriptor.
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t output_desc,
std::vector<infiniopTensorDescriptor_t> input_descs) {
// Variant A: an empty input list is reported as INFINI_STATUS_BAD_PARAM.
CHECK_OR_RETURN(!input_descs.empty(), INFINI_STATUS_BAD_PARAM);
auto input_desc = input_descs.at(0);
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

auto dtype = output_desc->dtype();
CHECK_DTYPE(dtype,
INFINI_DTYPE_BYTE,
INFINI_DTYPE_BOOL,
INFINI_DTYPE_I8,
INFINI_DTYPE_I16,
INFINI_DTYPE_I32,
INFINI_DTYPE_I64,
INFINI_DTYPE_U8,
INFINI_DTYPE_U16,
INFINI_DTYPE_U32,
INFINI_DTYPE_U64,
INFINI_DTYPE_F16,
INFINI_DTYPE_F32,
INFINI_DTYPE_F64,
INFINI_DTYPE_BF16);
auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
auto dtype = out_desc->dtype();

// Variant A: besides matching shapes, the output must have no broadcast
// dimension and no negative stride (storageSpanBytes relies on the latter).
CHECK_SAME_SHAPE(output_desc->shape(), input_desc->shape());
CHECK_OR_RETURN(!output_desc->hasBroadcastDim(), INFINI_STATUS_BAD_TENSOR_STRIDES);
for (auto stride : output_desc->strides()) {
CHECK_OR_RETURN(stride >= 0, INFINI_STATUS_BAD_TENSOR_STRIDES);
}
// Variant B: no emptiness check precedes this access, so an empty
// input_desc_vec makes at(0) throw std::out_of_range instead of returning a
// status code.
const auto &x_desc = input_desc_vec.at(0);

CHECK_DTYPE(dtype,
INFINI_DTYPE_BYTE, // 1
INFINI_DTYPE_BOOL, // 2
INFINI_DTYPE_I8, // 3
INFINI_DTYPE_I16, // 4
INFINI_DTYPE_I32, // 5
INFINI_DTYPE_I64, // 6
INFINI_DTYPE_U8, // 7
INFINI_DTYPE_U16, // 8
INFINI_DTYPE_U32, // 9
INFINI_DTYPE_U64, // 10
INFINI_DTYPE_F16, // 12
INFINI_DTYPE_F32, // 13
INFINI_DTYPE_F64, // 14
INFINI_DTYPE_BF16 // 19
);

CHECK_SAME_SHAPE(out_desc->shape(), x_desc->shape());

// Variant B: descriptor construction is delegated to the shared elementwise
// macro. NOTE(review): presumably it also assigns *desc_ptr - confirm against
// the macro definition.
CREATE_ELEMENTWISE_BANG_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

// Variant A: the descriptor records the byte span of the output storage and
// owns a heap-allocated Opaque holding the handle internals.
auto handle_bang = reinterpret_cast<device::bang::Handle *>(handle);
*desc_ptr = new Descriptor(
storageSpanBytes(output_desc),
new Opaque{handle_bang->internal()},
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}

Expand All @@ -106,18 +55,51 @@ infiniStatus_t Descriptor::calculate(
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
(void)workspace;
(void)workspace_size;
(void)inputs;
if (_storage_size == 0) {
return INFINI_STATUS_SUCCESS;
void *queue) const {

if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}

auto queue = reinterpret_cast<cnrtQueue_t>(stream);
int core_per_cluster = _opaque->internal->getCorePerCluster();
int cluster_count = _opaque->internal->getClusterCount();
return launchZeros(core_per_cluster, cluster_count, queue, output, _storage_size);
switch (_dtype) {
case INFINI_DTYPE_BYTE: // 1
return _device_info->calculate<ZerosOp, uint8_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_BOOL: // 2
return _device_info->calculate<ZerosOp, bool>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_I8: // 3
return _device_info->calculate<ZerosOp, int8_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_I16: // 4
return _device_info->calculate<ZerosOp, int16_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_I32: // 5
return _device_info->calculate<ZerosOp, int32_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_I64: // 6
return _device_info->calculate<ZerosOp, int64_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_U8: // 7
return _device_info->calculate<ZerosOp, uint8_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_U16: // 8
return _device_info->calculate<ZerosOp, uint16_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_U32: // 9
return _device_info->calculate<ZerosOp, uint32_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_U64: // 10
return _device_info->calculate<ZerosOp, uint64_t>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_F8: // 11
return INFINI_STATUS_NOT_IMPLEMENTED;
case INFINI_DTYPE_F16: // 12
return _device_info->calculate<ZerosOp, half>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_F32: // 13
return _device_info->calculate<ZerosOp, float>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_F64: // 14
return _device_info->calculate<ZerosOp, double>(_info, workspace, output, inputs, queue);
case INFINI_DTYPE_C16: // 15
case INFINI_DTYPE_C32: // 16
case INFINI_DTYPE_C64: // 17
case INFINI_DTYPE_C128: // 18
return INFINI_STATUS_NOT_IMPLEMENTED;
case INFINI_DTYPE_BF16: // 19
return _device_info->calculate<ZerosOp, bfloat16_t>(_info, workspace, output, inputs, queue);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}

} // namespace op::zeros::bang
39 changes: 39 additions & 0 deletions src/infiniop/ops/zeros/bang/zeros_bang_internal.mlu
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#ifndef __ZEROS_BANG_INTERNAL_H__
#define __ZEROS_BANG_INTERNAL_H__

#include "../../../elementwise/bang/elementwise_bang_kernel.h"

/**
 * Device-side functor for the `zeros` elementwise operator.
 *
 * Writes `num_elements` zeros of type T to `out`. Both input pointers are
 * accepted only to satisfy the call signature and are never read.
 * For `double` the elements are stored one by one in a scalar loop; every other
 * type goes through `__bang_write_value`.
 */
struct ZerosOp {
    static constexpr size_t num_inputs = 1;

    template <typename T>
    __mlu_device__ void operator()(T *out, const T *input, const T *unused, size_t num_elements) const {
        (void)input;
        (void)unused;
        if constexpr (!std::is_same_v<T, double>) {
            // Bulk fill for every non-double element type.
            __bang_write_value(out, num_elements, static_cast<T>(0));
        } else {
            // double takes the scalar store path.
            for (size_t idx = 0; idx != num_elements; ++idx) {
                out[idx] = 0.0;
            }
        }
    }
};

// NOTE(review): presumably defines the launchZerosKernel<T> template in terms of
// the ZerosOp functor above - confirm against elementwise_bang_kernel.h.
LAUNCH_ELEMENTWISE_KERNEL_IMPL(Zeros, ZerosOp)

// One instantiation per element type that Descriptor::calculate in
// zeros_bang.mlu dispatches to. uint8_t serves both INFINI_DTYPE_BYTE and
// INFINI_DTYPE_U8, which is why there are 13 entries for 14 supported dtypes.
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, uint8_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, bool)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, int8_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, int16_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, int32_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, int64_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, uint16_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, uint32_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, uint64_t)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, half)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, float)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, double)
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, bfloat16_t)

#endif // __ZEROS_BANG_INTERNAL_H__
Loading