diff --git a/rust/chg/src/message.rs b/rust/chg/src/message.rs
--- a/rust/chg/src/message.rs
+++ b/rust/chg/src/message.rs
@@ -5,7 +5,7 @@
 
 //! Utility for parsing and building command-server messages.
 
-use bytes::Bytes;
+use bytes::{BufMut, Bytes, BytesMut};
 use std::error;
 use std::ffi::{OsStr, OsString};
 use std::io;
@@ -67,6 +67,52 @@
     }
 }
 
+// allocate large buffer as environment variables can be quite long
+const INITIAL_PACKED_ENV_VARS_CAPACITY: usize = 4096;
+
+/// Packs environment variables of platform encoding into bytes.
+///
+/// # Panics
+///
+/// Panics if key or value contains `\0` character, or key contains '='
+/// character.
+pub fn pack_env_vars_os<I, P>(vars: I) -> Bytes
+where
+    I: IntoIterator<Item = (P, P)>,
+    P: AsRef<OsStr>,
+{
+    let mut vars_iter = vars.into_iter();
+    if let Some((k, v)) = vars_iter.next() {
+        let mut dst = BytesMut::with_capacity(INITIAL_PACKED_ENV_VARS_CAPACITY);
+        pack_env_into(&mut dst, k.as_ref(), v.as_ref());
+        for (k, v) in vars_iter {
+            // entries are separated (not terminated) by NUL
+            dst.reserve(1);
+            dst.put_u8(b'\0');
+            pack_env_into(&mut dst, k.as_ref(), v.as_ref());
+        }
+        dst.freeze()
+    } else {
+        Bytes::new()
+    }
+}
+
+/// Appends a single `key=value` entry to `dst`, reserving space first.
+///
+/// # Panics
+///
+/// Panics if key or value contains `\0` character, or key contains '='
+/// character.
+fn pack_env_into(dst: &mut BytesMut, k: &OsStr, v: &OsStr) {
+    assert!(!k.as_bytes().contains(&0), "key shouldn't contain NUL");
+    assert!(!k.as_bytes().contains(&b'='), "key shouldn't contain '='");
+    assert!(!v.as_bytes().contains(&0), "value shouldn't contain NUL");
+    dst.reserve(k.as_bytes().len() + 1 + v.as_bytes().len());
+    dst.put_slice(k.as_bytes());
+    dst.put_u8(b'=');
+    dst.put_slice(v.as_bytes());
+}
+
 fn decode_latin1<S>(s: S) -> String
 where
     S: AsRef<[u8]>,
@@ -85,6 +131,7 @@
 mod tests {
     use super::*;
     use std::os::unix::ffi::OsStringExt;
+    use std::panic;
 
     #[test]
     fn parse_command_spec_good() {
@@ -127,7 +174,66 @@
         assert!(parse_command_spec(Bytes::from_static(b"paper\0less")).is_err());
     }
 
+    #[test]
+    fn pack_env_vars_os_good() {
+        assert_eq!(
+            pack_env_vars_os(vec![] as Vec<(OsString, OsString)>),
+            Bytes::new()
+        );
+        assert_eq!(
+            pack_env_vars_os(vec![os_string_pair_from(b"FOO", b"bar")]),
+            Bytes::from_static(b"FOO=bar")
+        );
+        assert_eq!(
+            pack_env_vars_os(vec![
+                os_string_pair_from(b"FOO", b""),
+                os_string_pair_from(b"BAR", b"baz")
+            ]),
+            Bytes::from_static(b"FOO=\0BAR=baz")
+        );
+    }
+
+    #[test]
+    fn pack_env_vars_os_large_key() {
+        let mut buf = vec![b'A'; INITIAL_PACKED_ENV_VARS_CAPACITY];
+        let envs = vec![os_string_pair_from(&buf, b"")];
+        buf.push(b'=');
+        assert_eq!(pack_env_vars_os(envs), Bytes::from(buf));
+    }
+
+    #[test]
+    fn pack_env_vars_os_large_value() {
+        let mut buf = vec![b'A', b'='];
+        buf.resize(INITIAL_PACKED_ENV_VARS_CAPACITY + 1, b'a');
+        let envs = vec![os_string_pair_from(&buf[..1], &buf[2..])];
+        assert_eq!(pack_env_vars_os(envs), Bytes::from(buf));
+    }
+
+    #[test]
+    fn pack_env_vars_os_nul_eq() {
+        assert!(panic::catch_unwind(|| {
+            pack_env_vars_os(vec![os_string_pair_from(b"\0", b"")])
+        })
+        .is_err());
+        assert!(panic::catch_unwind(|| {
+            pack_env_vars_os(vec![os_string_pair_from(b"FOO", b"\0bar")])
+        })
+        .is_err());
+        assert!(panic::catch_unwind(|| {
+            pack_env_vars_os(vec![os_string_pair_from(b"FO=", b"bar")])
+        })
+        .is_err());
+        assert_eq!(
+            pack_env_vars_os(vec![os_string_pair_from(b"FOO", b"=ba")]),
+            Bytes::from_static(b"FOO==ba")
+        );
+    }
+
     fn os_string_from(s: &[u8]) -> OsString {
         OsString::from_vec(s.to_vec())
     }
+
+    fn os_string_pair_from(k: &[u8], v: &[u8]) -> (OsString, OsString) {
+        (os_string_from(k), os_string_from(v))
+    }
 }