diff --git a/src/infiniop/ops/zeros/bang/zeros_bang.h b/src/infiniop/ops/zeros/bang/zeros_bang.h index e69de29bb..bc9bca190 100644 --- a/src/infiniop/ops/zeros/bang/zeros_bang.h +++ b/src/infiniop/ops/zeros/bang/zeros_bang.h @@ -0,0 +1,8 @@ +#ifndef __ZEROS_BANG_API_H__ +#define __ZEROS_BANG_API_H__ + +#include "../../../elementwise/bang/elementwise_bang.h" + +ELEMENTWISE_DESCRIPTOR(zeros, bang) + +#endif // __ZEROS_BANG_API_H__ diff --git a/src/infiniop/ops/zeros/bang/zeros_bang.mlu b/src/infiniop/ops/zeros/bang/zeros_bang.mlu index b90a9e5e7..0698f4b30 100644 --- a/src/infiniop/ops/zeros/bang/zeros_bang.mlu +++ b/src/infiniop/ops/zeros/bang/zeros_bang.mlu @@ -1,103 +1,52 @@ -#include "../../../devices/bang/common_bang.h" #include "zeros_bang.h" -#include - -__nram__ char zeros_nram_buffer[NRAM_MAX_SIZE]; - -__mlu_global__ void zerosKernel(uint8_t *__restrict__ output, size_t total_bytes) { - if (total_bytes == 0) { - return; - } - - uint8_t *cache = reinterpret_cast((reinterpret_cast(zeros_nram_buffer) + ALIGN_SIZE - 1) & ~(ALIGN_SIZE - 1)); - size_t max_chunk = NRAM_MAX_SIZE - ALIGN_SIZE; - - for (size_t start = taskId * max_chunk; start < total_bytes; start += taskDim * max_chunk) { - size_t current = std::min(max_chunk, total_bytes - start); - __bang_write_value(cache, current, static_cast(0)); - __memcpy(output + start, cache, current, NRAM2GDRAM); - } -} - -static infiniStatus_t launchZeros( - int core_per_cluster, - int cluster_count, - cnrtQueue_t queue, - void *output, - size_t total_bytes) { - cnrtDim3_t kernel_dim; - kernel_dim.x = core_per_cluster; - kernel_dim.y = cluster_count; - kernel_dim.z = 1; - - cnrtFunctionType_t func_type = total_bytes > 1024 * 1024 ? cnrtFuncTypeUnion1 : cnrtFuncTypeBlock; - zerosKernel<<>>(reinterpret_cast(output), total_bytes); - CNRT_CHECK(cnrtQueueSync(queue)); - return INFINI_STATUS_SUCCESS; -} +LAUNCH_ELEMENTWISE_KERNEL(Zeros) namespace op::zeros::bang { -struct Descriptor::Opaque { - std::shared_ptr internal; -}; - -Descriptor::~Descriptor() { - delete _opaque; -} - -static size_t storageSpanBytes(infiniopTensorDescriptor_t desc) { - if (desc->numel() == 0) { - return 0; +typedef struct ZerosOp { + static constexpr size_t num_inputs = 1; + template + static infiniStatus_t launch(Args... args) { + launchZerosKernel(args...); + return INFINI_STATUS_SUCCESS; } +} ZerosOp; - auto shape = desc->shape(); - auto byte_strides = desc->getByteStrides(); - size_t max_offset = 0; - for (size_t i = 0; i < shape.size(); ++i) { - max_offset += (shape[i] - 1) * static_cast(byte_strides[i]); - } - return max_offset + infiniSizeOf(desc->dtype()); -} +Descriptor::~Descriptor() = default; infiniStatus_t Descriptor::create( - infiniopHandle_t handle, + infiniopHandle_t handle_, Descriptor **desc_ptr, - infiniopTensorDescriptor_t output_desc, - std::vector input_descs) { - CHECK_OR_RETURN(!input_descs.empty(), INFINI_STATUS_BAD_PARAM); - auto input_desc = input_descs.at(0); + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { - auto dtype = output_desc->dtype(); - CHECK_DTYPE(dtype, - INFINI_DTYPE_BYTE, - INFINI_DTYPE_BOOL, - INFINI_DTYPE_I8, - INFINI_DTYPE_I16, - INFINI_DTYPE_I32, - INFINI_DTYPE_I64, - INFINI_DTYPE_U8, - INFINI_DTYPE_U16, - INFINI_DTYPE_U32, - INFINI_DTYPE_U64, - INFINI_DTYPE_F16, - INFINI_DTYPE_F32, - INFINI_DTYPE_F64, - INFINI_DTYPE_BF16); + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); - CHECK_SAME_SHAPE(output_desc->shape(), input_desc->shape()); - CHECK_OR_RETURN(!output_desc->hasBroadcastDim(), INFINI_STATUS_BAD_TENSOR_STRIDES); - for (auto stride : output_desc->strides()) { - CHECK_OR_RETURN(stride >= 0, INFINI_STATUS_BAD_TENSOR_STRIDES); - } + const auto &x_desc = input_desc_vec.at(0); + + CHECK_DTYPE(dtype, + INFINI_DTYPE_BYTE, // 1 + INFINI_DTYPE_BOOL, // 2 + INFINI_DTYPE_I8, // 3 + INFINI_DTYPE_I16, // 4 + INFINI_DTYPE_I32, // 5 + INFINI_DTYPE_I64, // 6 + INFINI_DTYPE_U8, // 7 + INFINI_DTYPE_U16, // 8 + INFINI_DTYPE_U32, // 9 + INFINI_DTYPE_U64, // 10 + INFINI_DTYPE_F16, // 12 + INFINI_DTYPE_F32, // 13 + INFINI_DTYPE_F64, // 14 + INFINI_DTYPE_BF16 // 19 + ); + + CHECK_SAME_SHAPE(out_desc->shape(), x_desc->shape()); + + CREATE_ELEMENTWISE_BANG_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) - auto handle_bang = reinterpret_cast(handle); - *desc_ptr = new Descriptor( - storageSpanBytes(output_desc), - new Opaque{handle_bang->internal()}, - handle->device, - handle->device_id); return INFINI_STATUS_SUCCESS; } @@ -106,18 +55,51 @@ infiniStatus_t Descriptor::calculate( size_t workspace_size, void *output, std::vector inputs, - void *stream) const { - (void)workspace; - (void)workspace_size; - (void)inputs; - if (_storage_size == 0) { - return INFINI_STATUS_SUCCESS; + void *queue) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; } - auto queue = reinterpret_cast(stream); - int core_per_cluster = _opaque->internal->getCorePerCluster(); - int cluster_count = _opaque->internal->getClusterCount(); - return launchZeros(core_per_cluster, cluster_count, queue, output, _storage_size); + switch (_dtype) { + case INFINI_DTYPE_BYTE: // 1 + return _device_info->calculate(_info, workspace, output, inputs, queue); + case INFINI_DTYPE_BOOL: // 2 + return _device_info->calculate(_info, workspace, output, inputs, queue); + case INFINI_DTYPE_I8: // 3 + return _device_info->calculate(_info, workspace, output, inputs, queue); + case INFINI_DTYPE_I16: // 4 + return _device_info->calculate(_info, workspace, output, inputs, queue); + case INFINI_DTYPE_I32: // 5 + return _device_info->calculate(_info, workspace, output, inputs, queue); + case INFINI_DTYPE_I64: // 6 + return _device_info->calculate(_info, workspace, output, inputs, queue); + case INFINI_DTYPE_U8: // 7 + return _device_info->calculate(_info, workspace, output, inputs, queue); + case INFINI_DTYPE_U16: // 8 + return _device_info->calculate(_info, workspace, output, inputs, queue); + case INFINI_DTYPE_U32: // 9 + return _device_info->calculate(_info, workspace, output, inputs, queue); + case INFINI_DTYPE_U64: // 10 + return _device_info->calculate(_info, workspace, output, inputs, queue); + case INFINI_DTYPE_F8: // 11 + return INFINI_STATUS_NOT_IMPLEMENTED; + case INFINI_DTYPE_F16: // 12 + return _device_info->calculate(_info, workspace, output, inputs, queue); + case INFINI_DTYPE_F32: // 13 + return _device_info->calculate(_info, workspace, output, inputs, queue); + case INFINI_DTYPE_F64: // 14 + return _device_info->calculate(_info, workspace, output, inputs, queue); + case INFINI_DTYPE_C16: // 15 + case INFINI_DTYPE_C32: // 16 + case INFINI_DTYPE_C64: // 17 + case INFINI_DTYPE_C128: // 18 + return INFINI_STATUS_NOT_IMPLEMENTED; + case INFINI_DTYPE_BF16: // 19 + return _device_info->calculate(_info, workspace, output, inputs, queue); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } } } // namespace op::zeros::bang diff --git a/src/infiniop/ops/zeros/bang/zeros_bang_internal.mlu b/src/infiniop/ops/zeros/bang/zeros_bang_internal.mlu new file mode 100644 index 000000000..f3890d339 --- /dev/null +++ b/src/infiniop/ops/zeros/bang/zeros_bang_internal.mlu @@ -0,0 +1,39 @@ +#ifndef __ZEROS_BANG_INTERNAL_H__ +#define __ZEROS_BANG_INTERNAL_H__ + +#include "../../../elementwise/bang/elementwise_bang_kernel.h" + +typedef struct ZerosOp { +public: + static constexpr size_t num_inputs = 1; + template + __mlu_device__ void operator()(T *out, const T *input, const T *unused, size_t num_elements) const { + (void)input; + (void)unused; + if constexpr (std::is_same_v) { + for (size_t i = 0; i < num_elements; ++i) { + out[i] = 0.0; + } + } else { + __bang_write_value(out, num_elements, static_cast(0)); + } + } +} ZerosOp; + +LAUNCH_ELEMENTWISE_KERNEL_IMPL(Zeros, ZerosOp) + +LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, uint8_t) +LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, bool) +LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, int8_t) +LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, int16_t) +LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, int32_t) +LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, int64_t) +LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, uint16_t) +LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, uint32_t) +LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, uint64_t) +LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, half) +LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, float) +LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, double) +LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(Zeros, bfloat16_t) + +#endif // __ZEROS_BANG_INTERNAL_H__