diff --git a/src/uu/head/src/head.rs b/src/uu/head/src/head.rs index 896e0c923e1..f047060f387 100644 --- a/src/uu/head/src/head.rs +++ b/src/uu/head/src/head.rs @@ -6,7 +6,7 @@ // spell-checker:ignore (vars) memrchr use clap::ArgMatches; -use memchr::memrchr_iter; +use memchr::{memchr_iter, memrchr_iter}; use std::ffi::OsString; use std::fs::File; use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write}; @@ -31,7 +31,6 @@ mod parse; mod take; use take::copy_all_but_n_bytes; use take::copy_all_but_n_lines; -use take::take_lines; #[derive(Error, Debug)] enum HeadError { @@ -193,21 +192,56 @@ fn print_n_bytes(input: impl Read, n: u64) -> io::Result { Ok(bytes_written) } -fn print_n_lines(input: &mut impl io::BufRead, n: u64, separator: u8) -> io::Result { - // Read the first `n` lines from the `input` reader. - let mut reader = take_lines(input, n, separator); +enum HeadFileError { + Read(io::Error), + WriteStdout(io::Error), +} - // Write those bytes to `stdout`. +fn print_n_lines( + input: &mut impl io::BufRead, + n: u64, + separator: u8, +) -> Result { let stdout = io::stdout(); let stdout = stdout.lock(); let mut writer = BufWriter::with_capacity(BUF_SIZE, stdout); - let bytes_written = io::copy(&mut reader, &mut writer).map_err(wrap_in_stdout_error)?; + let mut bytes_written: u64 = 0; + let mut remaining = n; + while remaining > 0 { + let chunk = input.fill_buf().map_err(HeadFileError::Read)?; + + if chunk.is_empty() { + break; + } + + let mut take_len = chunk.len(); // default: take everything + let mut separators_seen: u64 = 0; + for separator_idx in memchr_iter(separator, chunk) { + separators_seen += 1; + if separators_seen == remaining { + take_len = separator_idx + 1; // include the separator itself + break; + } + } + remaining -= separators_seen; + + writer + .write_all(&chunk[..take_len]) + .map_err(wrap_in_stdout_error) + .map_err(HeadFileError::WriteStdout)?; + + input.consume(take_len); + bytes_written += take_len as u64; + } // Make sure we finish writing everything to the target before // exiting. Otherwise, when Rust is implicitly flushing, any // error will be silently ignored. - writer.flush().map_err(wrap_in_stdout_error)?; + writer + .flush() + .map_err(wrap_in_stdout_error) + .map_err(HeadFileError::WriteStdout)?; Ok(bytes_written) } @@ -388,15 +422,17 @@ fn head_backwards_on_seekable_file(input: &mut File, options: &HeadOptions) -> i } } -fn head_file(input: &mut File, options: &HeadOptions) -> io::Result { +fn head_file(input: &mut File, options: &HeadOptions) -> Result { match options.mode { - Mode::FirstBytes(n) => print_n_bytes(input, n), + Mode::FirstBytes(n) => print_n_bytes(input, n).map_err(HeadFileError::WriteStdout), Mode::FirstLines(n) => print_n_lines( &mut io::BufReader::with_capacity(BUF_SIZE, input), n, options.line_ending.into(), ), - Mode::AllButLastBytes(_) | Mode::AllButLastLines(_) => head_backwards_file(input, options), + Mode::AllButLastBytes(_) | Mode::AllButLastLines(_) => { + head_backwards_file(input, options).map_err(HeadFileError::WriteStdout) + } } } @@ -424,10 +460,30 @@ fn uu_head(options: &HeadOptions) -> UResult<()> { // last byte read so that any tools that parse the remainder of // the stdin stream read from the correct place. - let bytes_read = head_file(&mut stdin_file, options)?; + let bytes_read = match head_file(&mut stdin_file, options) { + Ok(bytes_read) => bytes_read, + Err(HeadFileError::Read(err)) => { + return Err(HeadError::Io { + name: "standard input".into(), + err, + } + .into()); + } + Err(HeadFileError::WriteStdout(err)) => return Err(err.into()), + }; stdin_file.seek(SeekFrom::Start(current_pos + bytes_read))?; } else { - let _bytes_read = head_file(&mut stdin_file, options)?; + match head_file(&mut stdin_file, options) { + Ok(_) => {} + Err(HeadFileError::Read(err)) => { + return Err(HeadError::Io { + name: "standard input".into(), + err, + } + .into()); + } + Err(HeadFileError::WriteStdout(err)) => return Err(err.into()), + } } } @@ -435,14 +491,31 @@ fn uu_head(options: &HeadOptions) -> UResult<()> { { let mut stdin = stdin.lock(); - match options.mode { - Mode::FirstBytes(n) => print_n_bytes(&mut stdin, n), - Mode::AllButLastBytes(n) => print_but_last_n_bytes(&mut stdin, n), + let result = match options.mode { + Mode::FirstBytes(n) => { + print_n_bytes(&mut stdin, n).map_err(HeadFileError::WriteStdout) + } + Mode::AllButLastBytes(n) => { + print_but_last_n_bytes(&mut stdin, n).map_err(HeadFileError::WriteStdout) + } Mode::FirstLines(n) => print_n_lines(&mut stdin, n, options.line_ending.into()), Mode::AllButLastLines(n) => { print_but_last_n_lines(&mut stdin, n, options.line_ending.into()) + .map_err(HeadFileError::WriteStdout) } - }?; + }; + + match result { + Ok(_) => {} + Err(HeadFileError::Read(err)) => { + return Err(HeadError::Io { + name: "standard input".into(), + err, + } + .into()); + } + Err(HeadFileError::WriteStdout(err)) => return Err(err.into()), + } } Ok(()) @@ -493,7 +566,17 @@ fn uu_head(options: &HeadOptions) -> UResult<()> { continue; } }; - head_file(&mut file_handle, options)?; + match head_file(&mut file_handle, options) { + Ok(_) => {} + Err(HeadFileError::Read(err)) => { + show!(HeadError::Io { + name: file.into(), + err + }); + continue; + } + Err(HeadFileError::WriteStdout(err)) => return Err(err.into()), + } Ok(()) }; if let Err(err) = res { diff --git a/src/uu/head/src/take.rs b/src/uu/head/src/take.rs index 552a624cc5a..57aae712415 100644 --- a/src/uu/head/src/take.rs +++ b/src/uu/head/src/take.rs @@ -303,62 +303,11 @@ pub fn copy_all_but_n_lines( Ok(total_bytes_copied) } -/// Like `std::io::Take`, but for lines instead of bytes. -/// -/// This struct is generally created by calling [`take_lines`] on a -/// reader. Please see the documentation of [`take_lines`] for more -/// details. -pub struct TakeLines { - inner: T, - limit: u64, - separator: u8, -} - -impl Read for TakeLines { - /// Read bytes from a buffer up to the requested number of lines. - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - if self.limit == 0 { - return Ok(0); - } - match self.inner.read(buf) { - Ok(0) => Ok(0), - Ok(n) => { - for i in memchr_iter(self.separator, &buf[..n]) { - self.limit -= 1; - if self.limit == 0 { - return Ok(i + 1); - } - } - Ok(n) - } - Err(e) => Err(e), - } - } -} - -/// Create an adaptor that will read at most `limit` lines from a given reader. -/// -/// This function returns a new instance of `Read` that will read at -/// most `limit` lines, after which it will always return EOF -/// (`Ok(0)`). -/// -/// The `separator` defines the character to interpret as the line -/// ending. For the usual notion of "line", set this to `b'\n'`. -pub fn take_lines(reader: R, limit: u64, separator: u8) -> TakeLines { - TakeLines { - inner: reader, - limit, - separator, - } -} - #[cfg(test)] mod tests { - use std::io::{BufRead, BufReader}; - use crate::take::{ - TakeAllBuffer, TakeAllLinesBuffer, copy_all_but_n_bytes, copy_all_but_n_lines, take_lines, + TakeAllBuffer, TakeAllLinesBuffer, copy_all_but_n_bytes, copy_all_but_n_lines, }; #[test] @@ -635,33 +584,4 @@ mod tests { assert_eq!(bytes_copied, 2); assert_eq!(output_reader.get_ref()[..], input_buffer.as_bytes()[0..2]); } - - #[test] - fn test_zero_lines() { - let input_reader = std::io::Cursor::new("a\nb\nc\n"); - let output_reader = BufReader::new(take_lines(input_reader, 0, b'\n')); - let mut iter = output_reader.lines().map(|l| l.unwrap()); - assert_eq!(None, iter.next()); - } - - #[test] - fn test_fewer_lines() { - let input_reader = std::io::Cursor::new("a\nb\nc\n"); - let output_reader = BufReader::new(take_lines(input_reader, 2, b'\n')); - let mut iter = output_reader.lines().map(|l| l.unwrap()); - assert_eq!(Some(String::from("a")), iter.next()); - assert_eq!(Some(String::from("b")), iter.next()); - assert_eq!(None, iter.next()); - } - - #[test] - fn test_more_lines() { - let input_reader = std::io::Cursor::new("a\nb\nc\n"); - let output_reader = BufReader::new(take_lines(input_reader, 4, b'\n')); - let mut iter = output_reader.lines().map(|l| l.unwrap()); - assert_eq!(Some(String::from("a")), iter.next()); - assert_eq!(Some(String::from("b")), iter.next()); - assert_eq!(Some(String::from("c")), iter.next()); - assert_eq!(None, iter.next()); - } } diff --git a/tests/by-util/test_head.rs b/tests/by-util/test_head.rs index 8406c3d9082..e6149dd18d8 100644 --- a/tests/by-util/test_head.rs +++ b/tests/by-util/test_head.rs @@ -251,6 +251,22 @@ fn test_multiple_nonexistent_files() { .stderr_contains("cannot open 'bogusfile2' for reading: No such file or directory"); } +#[test] +#[cfg(all(target_os = "linux", not(target_env = "musl")))] +#[cfg_attr(wasi_runner, ignore = "WASI sandbox: host paths not visible")] +fn test_multiple_files_read_error_continues_to_next_file() { + let ts = TestScenario::new(util_name!()); + let at = &ts.fixtures; + + at.write("a", "hello\n"); + + ts.ucmd() + .args(&["/proc/self/mem", "a"]) + .fails() + .stdout_is("==> /proc/self/mem <==\n\n==> a <==\nhello\n") + .stderr_contains("head: error reading '/proc/self/mem': Input/output error"); +} + // there was a bug not caught by previous tests // where for negative n > 3, the total amount of lines // was correct, but it would eat from the second line @@ -901,6 +917,25 @@ fn test_write_to_dev_full() { } } +#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "netbsd"))] +#[test] +fn test_write_to_dev_full_with_named_file() { + use std::fs::OpenOptions; + + let ts = TestScenario::new(util_name!()); + let at = &ts.fixtures; + + at.write("input", "hello\nworld\n"); + + let dev_full = OpenOptions::new().write(true).open("/dev/full").unwrap(); + + ts.ucmd() + .arg("input") + .set_stdout(dev_full) + .fails() + .stderr_is("head: error writing 'standard output': No space left on device\n"); +} + #[test] #[cfg(target_os = "linux")] #[cfg_attr(wasi_runner, ignore = "WASI: argv/filenames must be valid UTF-8")]