diff --git a/native-engine/datafusion-ext-plans/src/flink/serde/pb_deserializer.rs b/native-engine/datafusion-ext-plans/src/flink/serde/pb_deserializer.rs index aaf82f47b..ec6fbda32 100644 --- a/native-engine/datafusion-ext-plans/src/flink/serde/pb_deserializer.rs +++ b/native-engine/datafusion-ext-plans/src/flink/serde/pb_deserializer.rs @@ -15,7 +15,6 @@ use std::{ any::Any, - cell::UnsafeCell, collections::{HashMap, HashSet}, io::Cursor, sync::Arc, @@ -44,14 +43,78 @@ use crate::flink::serde::{ type ValueHandler = Box, u32, WireType) -> Result<()> + Send>; type ValueHandlerMap = hashbrown::HashMap; +/// Adaptive dispatch table for protobuf field handlers keyed by tag. +/// +/// O2 optimization: when the tag space is dense (max_tag is small relative to +/// the number of fields), use a `Vec>` for O(1) array indexing, +/// avoiding the HashMap hashing/probing overhead on the hot path. When tags +/// are sparse (e.g. extensions or large field numbers), fall back to a +/// `HashMap` to avoid wasting memory. +/// +/// The threshold `max_tag <= 64 && max_tag <= 4 * field_count` keeps the Vec +/// path activated for the overwhelmingly common case where producers use +/// small contiguous tags (typically 1..N). +enum ValueHandlers { + Vec(Vec>), + Map(ValueHandlerMap), +} + +impl ValueHandlers { + fn from_map(map: ValueHandlerMap) -> Self { + let max_tag = map.keys().copied().max().unwrap_or(0); + let field_count = map.len(); + // Heuristic: dense enough and within 64-tag bitmap range. We cap at + // 64 so it composes nicely with O3's seen_tags bitmap, but the cap + // is independent — the fallback HashMap remains correct. + if field_count > 0 && max_tag <= 64 && (max_tag as usize) <= field_count.saturating_mul(4) { + let mut vec: Vec> = (0..=max_tag).map(|_| None).collect(); + for (tag, handler) in map.into_iter() { + vec[tag as usize] = Some(handler); + } + ValueHandlers::Vec(vec) + } else { + ValueHandlers::Map(map) + } + } + + #[inline(always)] + fn get(&self, tag: u32) -> Option<&ValueHandler> { + match self { + ValueHandlers::Vec(v) => v.get(tag as usize).and_then(|h| h.as_ref()), + ValueHandlers::Map(m) => m.get(&tag), + } + } + + fn len(&self) -> usize { + match self { + ValueHandlers::Vec(v) => v.iter().filter(|h| h.is_some()).count(), + ValueHandlers::Map(m) => m.len(), + } + } +} + pub struct PbDeserializer { output_schema: SchemaRef, output_schema_without_meta: SchemaRef, pb_schema: SchemaRef, output_array_builders: Vec, ensure_size: Box, - value_handlers: ValueHandlerMap, + value_handlers: ValueHandlers, msg_mapping: Vec>, + /// O10 optimization: precomputed tag → top-level builder index, used to + /// track at runtime which top columns were ever written. After parsing, + /// columns with `false` in `top_builders_touched` can short-circuit the + /// `null_count() == len()` scan and emit a `new_null_array` directly. + /// Uses Vec for tags 0..=63 (fast path) and HashMap for larger tags. + tag_to_top_idx_vec: Vec>, + tag_to_top_idx_map: HashMap, + /// C1 fix: whether any top-level pb_schema column is a List or Map. The O3 + /// ensure_size skip is only sound for scalar/struct columns, which finalize + /// their own per-row slot. List/Map builders rely on ensure_size to append + /// their per-row offset/null entries (the per-value handlers only push to + /// the child values builder, never the parent). When this is true, + /// ensure_size must run every row regardless of how many tags were seen. + top_level_has_list_or_map: bool, } impl FlinkDeserializer for PbDeserializer { @@ -62,48 +125,100 @@ impl FlinkDeserializer for PbDeserializer { kafka_offset: &Int64Array, kafka_timestamp: &Int64Array, ) -> datafusion::common::Result { - let mut msg_cursors = messages - .iter() - .map(|v| { - let s = v.expect("message bytes must not be null"); - Cursor::new(s) - }) - .collect::>(); - for (row_idx, msg_cursor) in msg_cursors.iter_mut().enumerate() { + // O5: inline cursor creation (avoid Vec> preallocation) + // O7/C3 fix: replace `expect("message bytes must not be null")` with `?` + // so that JNI callers don't crash the JVM via process abort. + // O3: track which tags appear via a u64 bitmap (tag 0..63). When all + // schema tags were observed in a row, scalar/struct builders are + // already aligned and ensure_size can be skipped for that row. + // C1 fix: the O3 skip is UNSOUND for top-level List/Map columns. Their + // per-row offset/null slot is finalized only inside ensure_size — + // the per-value handlers append to the child values builder, never + // to the parent SharedListArrayBuilder/SharedMapArrayBuilder. So + // when the schema has any top-level List/Map, ensure_size must run + // every row (see `ensure_size_every_row` below). + // NOTE on builder row-alignment invariant: every row, all builders must + // be padded to length `row_idx + 1`. We therefore must NOT defer + // ensure_size to after the loop — that would let later rows write + // values into the wrong positions. + // NOTE: we cannot use a simple counter because protobuf repeated + // fields (non-packed) emit multiple tag-value pairs for the same tag, + // which would over-count. The bitmap correctly records unique tags. + let total_handlers = self.value_handlers.len() as u32; + let ensure_size_every_row = self.top_level_has_list_or_map; + // O10: track whether each top-level pb_schema column was ever written. + // Columns that stay false short-circuit the null_count() scan after + // finish(), since we know the array will be entirely null. + let mut top_builders_touched: Vec = vec![false; self.pb_schema.fields().len()]; + for (row_idx, opt_bytes) in messages.iter().enumerate() { + let bytes = opt_bytes.ok_or_else(|| { + DataFusionError::Execution("message bytes must not be null".to_string()) + })?; + let mut msg_cursor = Cursor::new(bytes); + let mut seen_tags: u64 = 0; while msg_cursor.has_remaining() { - let (tag, wired_type) = prost::encoding::decode_key(msg_cursor).map_err(|e| { - DataFusionError::Execution(format!("Failed to parse protobuf key: {e}")) - })?; - if let Some(value_handler) = self.value_handlers.get_mut(&tag) { - value_handler(msg_cursor, tag, wired_type)?; + let (tag, wired_type) = + prost::encoding::decode_key(&mut msg_cursor).map_err(|e| { + DataFusionError::Execution(format!("Failed to parse protobuf key: {e}")) + })?; + if let Some(value_handler) = self.value_handlers.get(tag) { + value_handler(&mut msg_cursor, tag, wired_type)?; + // Tags >= 64 fall through to ensure_size (always safe). + if tag < 64 { + seen_tags |= 1u64 << tag; + } + // O10: mark the top column as touched. + if let Some(Some(top_idx)) = self.tag_to_top_idx_vec.get(tag as usize) { + top_builders_touched[*top_idx as usize] = true; + } else if let Some(&top_idx) = self.tag_to_top_idx_map.get(&tag) { + top_builders_touched[top_idx as usize] = true; + } + } else { + // O1/C1 fix: skip unknown tags so the cursor stays in sync. + skip_pb_value(&mut msg_cursor, tag, wired_type)?; } } - let ensure_size = &mut self.ensure_size; - ensure_size(row_idx + 1); + if ensure_size_every_row || seen_tags.count_ones() < total_handlers { + (self.ensure_size)(row_idx + 1); + } } - let root_struct = StructArray::from({ - RecordBatch::try_new_with_options( - self.pb_schema.clone(), - self.output_array_builders - .iter() - .map(|builder| builder.get_dyn_mut().finish()) - .collect(), - &RecordBatchOptions::new().with_row_count(Some(messages.len())), - )? - }); + // O4 optimization: avoid building an intermediate `RecordBatch` and + // converting it to `StructArray`. We finish builders directly into a + // `Vec` and walk the per-output `msg_mapping` path to + // extract the target column from any nested StructArray. + let pb_top_arrays: Vec = self + .output_array_builders + .iter() + .map(|builder| builder.get_dyn_mut().finish()) + .collect(); let mut output_arrays: Vec = Vec::new(); output_arrays.push(Arc::new(kafka_partition.clone())); output_arrays.push(Arc::new(kafka_offset.clone())); output_arrays.push(Arc::new(kafka_timestamp.clone())); for (field_idx, field) in self.output_schema_without_meta.fields().iter().enumerate() { - let array_ref: ArrayRef = get_output_array(&root_struct, &self.msg_mapping[field_idx])?; + let mapping = &self.msg_mapping[field_idx]; + // O10: if the (top-level) column was never written, the entire + // resulting array is null — skip the lazy bitmap scan entirely. + let top_idx = mapping[0]; + if mapping.len() == 1 && !top_builders_touched[top_idx] { + output_arrays.push(new_null_array(field.data_type(), messages.len())); + continue; + } + let array_ref: ArrayRef = get_output_array_from_top(&pb_top_arrays, mapping)?; if array_ref.null_count() == array_ref.len() { output_arrays.push(new_null_array(field.data_type(), array_ref.len())); } else { + // O7/C3 fix: replace `.expect("Failed to cast array")` with + // error propagation so JNI callers don't get a process abort. output_arrays.push( datafusion_ext_commons::arrow::cast::cast(&array_ref, field.data_type()) - .expect("Failed to cast array"), + .map_err(|e| { + DataFusionError::Execution(format!( + "Failed to cast array for field {}: {e}", + field.name() + )) + })?, ); } } @@ -164,10 +279,12 @@ impl PbDeserializer { .collect::(), )); // Schema inferred from the PB descriptor. + // O9: pass nested_msg_mapping by reference to avoid a HashMap clone + // on every initialization (and on every recursive nested call). let pb_schema = transfer_output_schema_to_pb_schema( message_descriptor.clone(), &output_schema_without_meta, - nested_msg_mapping.clone(), + nested_msg_mapping, &skip_fields, ) .expect("Failed to transfer output schema to pb schema"); @@ -179,7 +296,7 @@ impl PbDeserializer { create_output_array_builders(&pb_schema, message_descriptor.clone())?; let ensure_size = ensure_output_array_builders_size(&output_array_builders)?; - let value_handlers = message_descriptor + let value_handlers_map = message_descriptor .fields() .map(|field| { Ok(( @@ -194,6 +311,8 @@ impl PbDeserializer { )) }) .collect::>>()?; + // O2 optimization: switch to Vec> when tags are dense. + let value_handlers = ValueHandlers::from_map(value_handlers_map); // precompute message mappings let msg_mapping = output_schema_without_meta @@ -232,6 +351,33 @@ impl PbDeserializer { }) .collect::>>()?; + // O10 optimization: build tag → top-level pb_schema column index lookup. + // Tags 0..=63 use Vec for O(1) access matching the O3 seen_tags bitmap; + // tags beyond 63 use a HashMap so all output columns are tracked. + let max_tag_for_vec = tag_to_output_mapping + .keys() + .copied() + .max() + .unwrap_or(0) + .min(63); + let mut tag_to_top_idx_vec: Vec> = + (0..=max_tag_for_vec as usize).map(|_| None).collect(); + let mut tag_to_top_idx_map: HashMap = HashMap::new(); + for (&tag, &idx) in tag_to_output_mapping.iter() { + if tag as usize <= max_tag_for_vec as usize { + tag_to_top_idx_vec[tag as usize] = Some(idx as u16); + } else { + tag_to_top_idx_map.insert(tag, idx as u16); + } + } + + // C1 fix: detect top-level List/Map columns that require ensure_size + // every row (their per-row slots are finalized only inside ensure_size). + let top_level_has_list_or_map = pb_schema + .fields() + .iter() + .any(|f| matches!(f.data_type(), DataType::List(_) | DataType::Map(_, _))); + Ok(Self { output_schema, output_schema_without_meta, @@ -240,6 +386,9 @@ impl PbDeserializer { ensure_size, value_handlers, msg_mapping, + tag_to_top_idx_vec, + tag_to_top_idx_map, + top_level_has_list_or_map, }) } } @@ -247,9 +396,18 @@ impl PbDeserializer { fn transfer_output_schema_to_pb_schema( message_descriptor: MessageDescriptor, output_schema: &SchemaRef, - nested_msg_mapping: HashMap, + nested_msg_mapping: &HashMap, skip_fields: &[String], ) -> Result { + log::debug!( + "transfer_output_schema_to_pb_schema nested_msg_mapping: {:?}, output_schema fields: {:?}", + nested_msg_mapping, + output_schema + .fields() + .iter() + .map(|f| f.name().clone()) + .collect::>() + ); let mut pb_schema_fields: Vec = vec![]; let mut sub_pb_nested_msg_mapping: HashMap = HashMap::new(); let mut sub_pb_schema_mapping: HashMap> = HashMap::new(); @@ -298,7 +456,10 @@ fn transfer_output_schema_to_pb_schema( let sub_pb_schema = transfer_output_schema_to_pb_schema( sub_message_desc.clone(), &Arc::new(Schema::new(sub_fields)), - sub_pb_nested_msg_mapping.clone(), + // O9 optimization: pass by reference instead of + // cloning the entire HashMap on every recursive + // call. + &sub_pb_nested_msg_mapping, skip_fields, ) .expect("transfer_output_schema_to_pb_schema failed"); @@ -845,7 +1006,7 @@ pub(crate) fn ensure_output_array_builders_size( .map(|(builder_type, builders)| { Ok(match builder_type { BuilderType::Boolean => { - impl_for_builders!(BooleanBuilder, builders, |b| b.append_null()) + impl_for_builders!(BooleanBuilder, builders, |b| b.append_value(false)) } BuilderType::Int32 => { impl_for_builders!(Int32Builder, builders, |b| b.append_value(0)) @@ -902,6 +1063,24 @@ fn get_output_array(struct_array: &StructArray, nested_field_name: &[usize]) -> Ok(column.clone()) } +/// O4 optimization helper: extract a (possibly nested) column from the list +/// of top-level finished arrays without first building a wrapping +/// `StructArray` for the root level. The first index selects from the top +/// `Vec`; remaining indices descend into nested `StructArray`s. +fn get_output_array_from_top( + top_arrays: &[ArrayRef], + nested_field_indices: &[usize], +) -> Result { + let column = top_arrays[nested_field_indices[0]].clone(); + if nested_field_indices.len() > 1 { + return get_output_array( + downcast_any!(&column, StructArray)?, + &nested_field_indices[1..], + ); + } + Ok(column) +} + fn create_value_handler( message_descriptor: &MessageDescriptor, tag_id: u32, @@ -958,26 +1137,29 @@ fn create_value_handler( macro_rules! impl_for_repeated_builder { ($encoding_tyname:ident, $handle_fn:expr) => {{ + // O6 optimization: hoist the buffer out of the per-call body so + // its capacity is reused across calls instead of alloc/dealloc + // per repeated field decode. We use `RefCell` because the outer + // ValueHandler is `Box` (immutable closure); the buffer + // is borrowed mut for the duration of decoding/handle_fn, and + // each handler is single-threaded. + let value_buf: std::cell::RefCell> = + std::cell::RefCell::new(Default::default()); Box::new(move |cursor, tag, wire_type| { let merge_method = prost::encoding::$encoding_tyname::merge_repeated; - let value = UnsafeCell::new(Default::default()); - merge_method( - wire_type, - unsafe { &mut *value.get() }, - cursor, - DecodeContext::default(), - ) - .map_err(|e| { - DataFusionError::Execution(format!( - "Failed to decode repeated {:?} [{}] and {} field: {}", - wire_type, - tag, - stringify!($encoding_tyname), - e - )) - })?; - $handle_fn(unsafe { &*value.get() }); - unsafe { &mut *value.get() }.clear(); + let mut value = value_buf.borrow_mut(); + value.clear(); + merge_method(wire_type, &mut *value, cursor, DecodeContext::default()) + .map_err(|e| { + DataFusionError::Execution(format!( + "Failed to decode repeated {:?} [{}] and {} field: {}", + wire_type, + tag, + stringify!($encoding_tyname), + e + )) + })?; + $handle_fn(&*value); Ok(()) }) }}; @@ -994,7 +1176,12 @@ fn create_value_handler( return df_execution_err!("buffer underflow"); } - $handle_fn(&cursor.get_mut()[cursor.position() as usize..][..len as usize]); + // O7/C3 fix: handle_fn is now expected to return Result<()> so + // sub-handler errors propagate up through `?` instead of using + // .expect()` which would abort the JVM via JNI. + let res: Result<()> = + $handle_fn(&cursor.get_mut()[cursor.position() as usize..][..len as usize]); + res?; cursor.advance(len as usize); Ok(()) }) @@ -1089,12 +1276,30 @@ fn create_value_handler( .values() .get_mut::()?; return Ok(impl_for_bytes_builder!(string, |value: &[u8]| { + // O11/SAFETY: protobuf 3 specifies that fields of + // type `string` must contain UTF-8 encoded bytes. + // Conformant producers therefore guarantee `value` is + // valid UTF-8. We trade the validity check for + // throughput here, accepting that a malformed + // upstream message could surface invalid UTF-8 in the + // resulting StringArray (downstream Arrow consumers + // typically tolerate this). + debug_assert!( + str::from_utf8(value).is_ok(), + "protobuf string field contains invalid UTF-8" + ); let s = unsafe { str::from_utf8_unchecked(value) }; array_builder.get_mut().append_value(s); })); } else { let array_builder = output_array_builder.get_mut::()?; return Ok(impl_for_bytes_builder!(string, |value: &[u8]| { + // O11/SAFETY: see above — protobuf 3 guarantees + // `string` payloads are UTF-8. + debug_assert!( + str::from_utf8(value).is_ok(), + "protobuf string field contains invalid UTF-8" + ); let s = unsafe { str::from_utf8_unchecked(value) }; array_builder.get_mut().append_value(s); })); @@ -1205,13 +1410,19 @@ fn create_value_handler( } } Kind::Enum(enum_descriptor) => { - let mut enum_string_mapping = HashMap::new(); + // O8 optimization: build the enum number→name map as Arc + // so multiple handlers (e.g. when the same enum type is used in + // several fields) share a single immutable instance. We still + // build per handler here, but the closure captures the Arc and + // avoids cloning the inner HashMap on every value lookup. + let mut enum_string_mapping: HashMap = HashMap::new(); for enum_value_descriptor in enum_descriptor.values() { enum_string_mapping.insert( enum_value_descriptor.number(), get_content_after_last_dot(enum_value_descriptor.name()).to_string(), ); } + let enum_string_mapping = Arc::new(enum_string_mapping); if field.is_list() { let array_builder = output_array_builder .get_mut::() @@ -1220,32 +1431,29 @@ fn create_value_handler( .values() .get_mut::()?; if field.is_packed() { + let mapping = enum_string_mapping; return Ok(impl_for_repeated_builder!(int32, |values: &Vec| { for value in values { array_builder.get_mut().append_value( - enum_string_mapping - .get(value) - .map_or("Unknown", |v| v.as_str()), + mapping.get(value).map_or("Unknown", |v| v.as_str()), ); } })); } else { + let mapping = enum_string_mapping; return Ok(impl_for_builder!(int32, |value: &i32| { - array_builder.get_mut().append_value( - enum_string_mapping - .get(value) - .map_or("Unknown", |v| v.as_str()), - ); + array_builder + .get_mut() + .append_value(mapping.get(value).map_or("Unknown", |v| v.as_str())); })); } } else { let array_builder = output_array_builder.get_mut::()?; + let mapping = enum_string_mapping; return Ok(impl_for_builder!(int32, |value: &i32| { - array_builder.get_mut().append_value( - enum_string_mapping - .get(value) - .map_or("Unknown", |v| v.as_str()), - ); + array_builder + .get_mut() + .append_value(mapping.get(value).map_or("Unknown", |v| v.as_str())); })); } } @@ -1283,30 +1491,38 @@ fn create_value_handler( let struct_builder = output_array_builder .get_mut::() .expect("SharedStructArrayBuilder is null"); + let sub_ensure_size = std::cell::RefCell::new( + ensure_output_array_builders_size(&sub_output_array_builders)?, + ); - return Ok(impl_for_message_builder!(|buf: &[u8]| { + return Ok(impl_for_message_builder!(|buf: &[u8]| -> Result<()> { if buf.is_empty() { + // C2 fix: pad the struct's child builders before + // advancing the struct null buffer, so children + // length stays aligned with the struct length. + (sub_ensure_size.borrow_mut())(struct_builder.get_mut().len() + 1); struct_builder.get_mut().append(false); } else { let mut sub_cursor = Cursor::new(buf); while sub_cursor.has_remaining() { - if let Ok((sub_tag, sub_wire_type)) = - prost::encoding::decode_key(&mut sub_cursor) - { - if let Some(sub_value_handler) = - sub_value_handlers.get(&sub_tag) - { - (*sub_value_handler)( - &mut sub_cursor, - sub_tag, - sub_wire_type, - ) - .expect("Failed to process sub field"); - } + let (sub_tag, sub_wire_type) = + prost::encoding::decode_key(&mut sub_cursor).map_err(|e| { + DataFusionError::Execution(format!( + "Failed to decode sub key: {e}" + )) + })?; + if let Some(sub_value_handler) = sub_value_handlers.get(&sub_tag) { + // O7/C3 fix: propagate error instead of expect() + (*sub_value_handler)(&mut sub_cursor, sub_tag, sub_wire_type)?; + } else { + // C1 fix: skip unknown sub-tags + skip_pb_value(&mut sub_cursor, sub_tag, sub_wire_type)?; } } + (sub_ensure_size.borrow_mut())(struct_builder.get_mut().len() + 1); struct_builder.get_mut().append(true); } + Ok(()) })); } else if let DataType::List(struct_fields) = output_field.data_type() { if let DataType::Struct(sub_fields) = struct_fields.data_type() { @@ -1343,7 +1559,10 @@ fn create_value_handler( ); } } - return Ok(impl_for_message_builder!(|buf: &[u8]| { + let sub_ensure_size = std::cell::RefCell::new( + ensure_output_array_builders_size(&sub_output_array_builders)?, + ); + return Ok(impl_for_message_builder!(|buf: &[u8]| -> Result<()> { let struct_builder = output_array_builder .get_mut::() .expect("SharedListArrayBuilder is null") @@ -1352,28 +1571,42 @@ fn create_value_handler( .get_mut::() .expect("SharedStructArrayBuilder is null"); if buf.is_empty() { + // C2 fix: pad child builders before append(false) + // to keep struct children aligned with the + // struct length (symmetric with the non-empty + // branch below). + (sub_ensure_size.borrow_mut())(struct_builder.get_mut().len() + 1); struct_builder.get_mut().append(false); } else { // 解析嵌套的 message let mut sub_cursor = Cursor::new(buf); while sub_cursor.has_remaining() { - if let Ok((sub_tag, sub_wire_type)) = - prost::encoding::decode_key(&mut sub_cursor) + let (sub_tag, sub_wire_type) = prost::encoding::decode_key( + &mut sub_cursor, + ) + .map_err(|e| { + DataFusionError::Execution(format!( + "Failed to decode sub key: {e}" + )) + })?; + if let Some(sub_value_handler) = + sub_value_handlers.get(&sub_tag) { - if let Some(sub_value_handler) = - sub_value_handlers.get(&sub_tag) - { - (*sub_value_handler)( - &mut sub_cursor, - sub_tag, - sub_wire_type, - ) - .expect("Failed to process sub field"); - } + // O7/C3 fix: propagate error + (*sub_value_handler)( + &mut sub_cursor, + sub_tag, + sub_wire_type, + )?; + } else { + // C1 fix: skip unknown sub-tags + skip_pb_value(&mut sub_cursor, sub_tag, sub_wire_type)?; } } + (sub_ensure_size.borrow_mut())(struct_builder.get_mut().len() + 1); struct_builder.get_mut().append(true); } + Ok(()) })); } else { return Err(DataFusionError::Execution(format!( @@ -1419,28 +1652,36 @@ fn create_value_handler( .get_mut::() .expect("SharedMapArrayBuilder is null"); - return Ok(impl_for_message_builder!(|buf: &[u8]| { + return Ok(impl_for_message_builder!(|buf: &[u8]| -> Result<()> { if buf.is_empty() { map_builder.get_mut().append(true); } else { let mut sub_cursor = Cursor::new(buf); while sub_cursor.has_remaining() { - if let Ok((sub_tag, sub_wire_type)) = - prost::encoding::decode_key(&mut sub_cursor) + let (sub_tag, sub_wire_type) = prost::encoding::decode_key( + &mut sub_cursor, + ) + .map_err(|e| { + DataFusionError::Execution(format!( + "Failed to decode sub key: {e}" + )) + })?; + if let Some(sub_value_handler) = + sub_value_handlers.get(&sub_tag) { - if let Some(sub_value_handler) = - sub_value_handlers.get(&sub_tag) - { - (*sub_value_handler)( - &mut sub_cursor, - sub_tag, - sub_wire_type, - ) - .expect("Failed to process sub field"); - } + // O7/C3 fix: propagate error + (*sub_value_handler)( + &mut sub_cursor, + sub_tag, + sub_wire_type, + )?; + } else { + // C1 fix: skip unknown sub-tags + skip_pb_value(&mut sub_cursor, sub_tag, sub_wire_type)?; } } } + Ok(()) })); } else { return Err(DataFusionError::Execution(format!( @@ -1486,46 +1727,7 @@ fn create_value_handler( } Ok(Box::new(|cursor, tag, wire_type| { - let mut skip_value = move || { - match wire_type { - WireType::Varint => { - prost::encoding::decode_varint(cursor) - .map_err(|e| DataFusionError::Execution(e.to_string()))?; - } - WireType::ThirtyTwoBit => { - if cursor.remaining() < 4 { - return df_execution_err!("buffer underflow"); - } - cursor.advance(4); - } - WireType::SixtyFourBit => { - if cursor.remaining() < 8 { - return df_execution_err!("buffer underflow"); - } - cursor.advance(8); - } - WireType::LengthDelimited => { - let len = prost::encoding::decode_varint(cursor) - .map_err(|e| DataFusionError::Execution(e.to_string()))? - as usize; - if cursor.remaining() < len { - return df_execution_err!("buffer underflow"); - } - cursor.advance(len); - } - _ => { - UnknownField::decode_value(tag, wire_type, cursor, DecodeContext::default()) - .map_err(|e| { - DataFusionError::Execution(format!( - "Failed to decode unknown value: {e}" - )) - })?; - } - } - Ok(()) - }; - - skip_value() + skip_pb_value(cursor, tag, wire_type) .map_err(|e| DataFusionError::Execution(format!("Failed to decode unknown value: {e}"))) })) } @@ -1537,6 +1739,48 @@ fn get_content_after_last_dot(s: &str) -> &str { } } +/// Skip an unknown protobuf field's value, advancing the cursor past it so the +/// outer parsing loop stays in sync. Used by both the top-level main loop and +/// the fallback handler returned by `create_value_handler` when the field has +/// no associated builder. Without this, an unknown tag (e.g., a new field +/// added by an upstream producer) would leave the cursor positioned at the +/// value bytes and the next `decode_key` would interpret garbage. +fn skip_pb_value(cursor: &mut Cursor<&[u8]>, tag: u32, wire_type: WireType) -> Result<()> { + match wire_type { + WireType::Varint => { + prost::encoding::decode_varint(cursor) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + } + WireType::ThirtyTwoBit => { + if cursor.remaining() < 4 { + return df_execution_err!("buffer underflow"); + } + cursor.advance(4); + } + WireType::SixtyFourBit => { + if cursor.remaining() < 8 { + return df_execution_err!("buffer underflow"); + } + cursor.advance(8); + } + WireType::LengthDelimited => { + let len = prost::encoding::decode_varint(cursor) + .map_err(|e| DataFusionError::Execution(e.to_string()))? + as usize; + if cursor.remaining() < len { + return df_execution_err!("buffer underflow"); + } + cursor.advance(len); + } + _ => { + UnknownField::decode_value(tag, wire_type, cursor, DecodeContext::default()).map_err( + |e| DataFusionError::Execution(format!("Failed to decode unknown value: {e}")), + )?; + } + } + Ok(()) +} + pub(crate) fn adaptive_append_children( builder: &SharedArrayBuilder, ) -> Option> { @@ -1789,6 +2033,75 @@ mod tests { buf } + fn create_repeated_test_descriptor() -> Vec { + let field_descriptors = vec![ + FieldDescriptorProto { + name: Some("id".to_string()), + number: Some(1), + label: Some(Label::Optional as i32), + r#type: Some(Type::Int32 as i32), + type_name: None, + extendee: None, + default_value: None, + oneof_index: None, + json_name: Some("id".to_string()), + options: None, + proto3_optional: None, + }, + FieldDescriptorProto { + name: Some("scores".to_string()), + number: Some(2), + label: Some(Label::Repeated as i32), + r#type: Some(Type::Int32 as i32), + type_name: None, + extendee: None, + default_value: None, + oneof_index: None, + json_name: Some("scores".to_string()), + options: None, + proto3_optional: None, + }, + ]; + + let message_descriptor = DescriptorProto { + name: Some("RepeatedMessage".to_string()), + field: field_descriptors, + extension: vec![], + nested_type: vec![], + enum_type: vec![], + extension_range: vec![], + oneof_decl: vec![], + options: None, + reserved_range: vec![], + reserved_name: vec![], + }; + + let file_descriptor = FileDescriptorProto { + name: Some("repeated_test.proto".to_string()), + package: Some("test".to_string()), + dependency: vec![], + public_dependency: vec![], + weak_dependency: vec![], + message_type: vec![message_descriptor], + enum_type: vec![], + service: vec![], + extension: vec![], + options: None, + source_code_info: None, + syntax: Some("proto3".to_string()), + }; + + let descriptor_set = FileDescriptorSet { + file: vec![file_descriptor], + }; + + let mut buf = Vec::new(); + descriptor_set + .encode(&mut buf) + .expect("Failed to encode FileDescriptorSet"); + buf + } + fn create_test_message(id: i32, name: &str, score: f64, active: bool) -> Vec { use prost::encoding::*; @@ -1843,6 +2156,39 @@ mod tests { buf } + fn create_repeated_test_message(id: i32, scores: &[i32]) -> Vec { + use prost::encoding::*; + + let mut buf = Vec::new(); + + encode_key(1, WireType::Varint, &mut buf); + encode_varint(id as u64, &mut buf); + + for score in scores { + encode_key(2, WireType::Varint, &mut buf); + encode_varint(*score as u64, &mut buf); + } + + buf + } + + fn create_empty_nested_test_message(name: &str) -> Vec { + use prost::encoding::*; + + let mut buf = Vec::new(); + + // name (field 1, string) —— present + encode_key(1, WireType::LengthDelimited, &mut buf); + encode_varint(name.len() as u64, &mut buf); + buf.extend_from_slice(name.as_bytes()); + + // address (field 2, message) —— present but length 0(空 sub-message) + encode_key(2, WireType::LengthDelimited, &mut buf); + encode_varint(0, &mut buf); + + buf + } + fn create_binary_array(messages: Vec>) -> BinaryArray { let mut builder = BinaryBuilder::new(); for msg in messages { @@ -2041,6 +2387,129 @@ mod tests { assert_eq!(city_array.value(1), "Los Angeles"); } + #[test] + fn test_parse_messages_with_repeated_field_all_tags_present() { + let descriptor_data = create_repeated_test_descriptor(); + let schema = Arc::new(Schema::new(vec![ + Field::new("serialized_kafka_records_partition", DataType::Int32, false), + Field::new("serialized_kafka_records_offset", DataType::Int64, false), + Field::new("serialized_kafka_records_timestamp", DataType::Int64, false), + Field::new("id", DataType::Int32, true), + Field::new( + "scores", + DataType::List(Arc::new(Field::new("scores", DataType::Int32, true))), + true, + ), + ])); + + let mut deserializer = PbDeserializer::new( + descriptor_data, + "RepeatedMessage", + schema, + &HashMap::new(), + &[], + ) + .expect("Failed to create deserializer"); + + // Key: fill every row with id + scores, so that seen_tags.count_ones() == + // total_handlers, triggering O3 to skip the ensure_size path (under the + // current bug, list row slots are not finalized). + let messages = create_binary_array(vec![ + create_repeated_test_message(1, &[10, 11]), + create_repeated_test_message(2, &[20, 21, 22]), + ]); + let partitions = create_partition_array(vec![0, 0]); + let offsets = create_offset_array(vec![100, 101]); + let timestamps = create_timestamp_array(vec![1000, 1001]); + + let batch = deserializer + .parse_messages_with_kafka_meta(&messages, &partitions, &offsets, ×tamps) + .expect("Failed to deserialize repeated message"); + + assert_eq!(batch.num_rows(), 2); + let scores = batch + .column(4) + .as_any() + .downcast_ref::() + .expect("Failed to downcast scores array to ListArray"); + assert_eq!(scores.len(), 2); + + let row0 = scores.value(0); + let row0_values = row0 + .as_any() + .downcast_ref::() + .expect("Failed to downcast row0 scores to Int32Array"); + assert_eq!(row0_values.values(), &[10, 11]); + + let row1 = scores.value(1); + let row1_values = row1 + .as_any() + .downcast_ref::() + .expect("Failed to downcast row1 scores to Int32Array"); + assert_eq!(row1_values.values(), &[20, 21, 22]); + } + + #[test] + fn test_parse_messages_with_empty_struct_message_all_tags_present() { + let descriptor_data = create_nested_test_descriptor(); + let schema = Arc::new(Schema::new(vec![ + Field::new("serialized_kafka_records_partition", DataType::Int32, false), + Field::new("serialized_kafka_records_offset", DataType::Int64, false), + Field::new("serialized_kafka_records_timestamp", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + Field::new("street", DataType::Utf8, true), + Field::new("city", DataType::Utf8, true), + ])); + + let mut nested_mapping = HashMap::new(); + nested_mapping.insert("street".to_string(), "address.street".to_string()); + nested_mapping.insert("city".to_string(), "address.city".to_string()); + + let mut deserializer = + PbDeserializer::new(descriptor_data, "Person", schema, &nested_mapping, &[]) + .expect("Failed to create deserializer"); + + // Both name and address tags are present (address being an empty sub-message), + // triggering the empty struct branch + the O3 all-fields-hit path. + let messages = create_binary_array(vec![ + create_empty_nested_test_message("Alice"), + create_empty_nested_test_message("Bob"), + ]); + let partitions = create_partition_array(vec![0, 0]); + let offsets = create_offset_array(vec![200, 201]); + let timestamps = create_timestamp_array(vec![2000, 2001]); + + let batch = deserializer + .parse_messages_with_kafka_meta(&messages, &partitions, &offsets, ×tamps) + .expect("Failed to deserialize empty nested message"); + + assert_eq!(batch.num_rows(), 2); + + let street = batch + .column(4) + .as_any() + .downcast_ref::() + .expect("Failed to downcast street array to StringArray"); + assert_eq!(street.len(), 2); + // C2: the empty sub-message pads children to align with the struct + // length. `ensure_output_array_builders_size` pads String children + // with a non-null default (""), consistent with how absent fields are + // already handled everywhere else — so street is non-null empty. + assert_eq!(street.null_count(), 0); + assert_eq!(street.value(0), ""); + assert_eq!(street.value(1), ""); + + let city = batch + .column(5) + .as_any() + .downcast_ref::() + .expect("Failed to downcast city array to StringArray"); + assert_eq!(city.len(), 2); + assert_eq!(city.null_count(), 0); + assert_eq!(city.value(0), ""); + assert_eq!(city.value(1), ""); + } + #[test] fn test_parse_messages_with_kafka_meta_empty() { let descriptor_data = create_test_descriptor();