Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 101 additions & 18 deletions src/uu/head/src/head.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
// spell-checker:ignore (vars) memrchr

use clap::ArgMatches;
use memchr::memrchr_iter;
use memchr::{memchr_iter, memrchr_iter};
use std::ffi::OsString;
use std::fs::File;
use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write};
Expand All @@ -31,7 +31,6 @@ mod parse;
mod take;
use take::copy_all_but_n_bytes;
use take::copy_all_but_n_lines;
use take::take_lines;

#[derive(Error, Debug)]
enum HeadError {
Expand Down Expand Up @@ -193,21 +192,56 @@ fn print_n_bytes(input: impl Read, n: u64) -> io::Result<u64> {
Ok(bytes_written)
}

fn print_n_lines(input: &mut impl io::BufRead, n: u64, separator: u8) -> io::Result<u64> {
// Read the first `n` lines from the `input` reader.
let mut reader = take_lines(input, n, separator);
/// Failure of a single `head` operation, split by which side of the copy failed.
///
/// Keeping the two sides apart lets callers report a read failure against the
/// input it came from while treating a standard-output failure separately.
///
/// `Debug` is derived so that `Result<_, HeadFileError>` works with
/// `unwrap`/`expect` and `{:?}` formatting.
#[derive(Debug)]
enum HeadFileError {
    /// The input could not be read.
    Read(io::Error),
    /// Writing to (or flushing) standard output failed.
    WriteStdout(io::Error),
}

// Write those bytes to `stdout`.
/// Copies the first `n` `separator`-terminated lines of `input` to standard output.
///
/// The input is processed one buffered chunk at a time: each chunk obtained
/// from `fill_buf` is scanned for separators, the part up to and including
/// the last wanted separator is written, and exactly that many bytes are
/// consumed from `input`. Nothing past the final wanted separator is consumed.
///
/// Returns the number of bytes written, which equals the number of bytes
/// consumed from `input`.
///
/// # Errors
///
/// * `HeadFileError::Read` if the input cannot be read.
/// * `HeadFileError::WriteStdout` if writing to or flushing standard output fails.
fn print_n_lines(
    input: &mut impl io::BufRead,
    n: u64,
    separator: u8,
) -> Result<u64, HeadFileError> {
    let stdout = io::stdout();
    let stdout = stdout.lock();
    let mut writer = BufWriter::with_capacity(BUF_SIZE, stdout);

    let mut bytes_written: u64 = 0;
    let mut remaining = n;
    while remaining > 0 {
        let chunk = match input.fill_buf() {
            Ok(chunk) => chunk,
            // An interrupted read is transient; retry it the same way
            // `io::copy` and `BufRead::read_until` do instead of reporting it
            // as a read failure.
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(HeadFileError::Read(err)),
        };

        // An empty buffer means end of input.
        if chunk.is_empty() {
            break;
        }

        let mut take_len = chunk.len(); // default: take everything
        let mut separators_seen: u64 = 0;
        for separator_idx in memchr_iter(separator, chunk) {
            separators_seen += 1;
            if separators_seen == remaining {
                take_len = separator_idx + 1; // include the separator itself
                break;
            }
        }
        // `separators_seen` never exceeds `remaining` because the scan stops
        // as soon as the two are equal, so this cannot underflow.
        remaining -= separators_seen;

        writer
            .write_all(&chunk[..take_len])
            .map_err(wrap_in_stdout_error)
            .map_err(HeadFileError::WriteStdout)?;

        input.consume(take_len);
        bytes_written += take_len as u64;
    }

    // Make sure we finish writing everything to the target before
    // exiting. Otherwise, when Rust is implicitly flushing, any
    // error will be silently ignored.
    writer
        .flush()
        .map_err(wrap_in_stdout_error)
        .map_err(HeadFileError::WriteStdout)?;

    Ok(bytes_written)
}
Expand Down Expand Up @@ -388,15 +422,17 @@ fn head_backwards_on_seekable_file(input: &mut File, options: &HeadOptions) -> i
}
}

fn head_file(input: &mut File, options: &HeadOptions) -> io::Result<u64> {
/// Dispatches on `options.mode` and prints the requested part of `input`.
///
/// `Mode::FirstLines` reads through a `BufReader` of `BUF_SIZE` capacity and
/// returns whatever `print_n_lines` returns. For the other modes, any error
/// from the helper is wrapped as `HeadFileError::WriteStdout`.
fn head_file(input: &mut File, options: &HeadOptions) -> Result<u64, HeadFileError> {
match options.mode {
Mode::FirstBytes(n) => print_n_bytes(input, n),
// NOTE(review): every error from `print_n_bytes` is tagged as a stdout write
// failure here. If that helper can also fail while reading `input`, such an
// error would not be reported as `HeadFileError::Read` — confirm.
Mode::FirstBytes(n) => print_n_bytes(input, n).map_err(HeadFileError::WriteStdout),
Mode::FirstLines(n) => print_n_lines(
&mut io::BufReader::with_capacity(BUF_SIZE, input),
n,
options.line_ending.into(),
),
Mode::AllButLastBytes(_) | Mode::AllButLastLines(_) => head_backwards_file(input, options),
// NOTE(review): same classification as above — all `head_backwards_file`
// errors become `WriteStdout`; confirm that is intended.
Mode::AllButLastBytes(_) | Mode::AllButLastLines(_) => {
head_backwards_file(input, options).map_err(HeadFileError::WriteStdout)
}
}
}

Expand Down Expand Up @@ -424,25 +460,62 @@ fn uu_head(options: &HeadOptions) -> UResult<()> {
// last byte read so that any tools that parse the remainder of
// the stdin stream read from the correct place.

let bytes_read = head_file(&mut stdin_file, options)?;
let bytes_read = match head_file(&mut stdin_file, options) {
Ok(bytes_read) => bytes_read,
Err(HeadFileError::Read(err)) => {
return Err(HeadError::Io {
name: "standard input".into(),
err,
}
.into());
}
Err(HeadFileError::WriteStdout(err)) => return Err(err.into()),
};
stdin_file.seek(SeekFrom::Start(current_pos + bytes_read))?;
} else {
let _bytes_read = head_file(&mut stdin_file, options)?;
match head_file(&mut stdin_file, options) {
Ok(_) => {}
Err(HeadFileError::Read(err)) => {
return Err(HeadError::Io {
name: "standard input".into(),
err,
}
.into());
}
Err(HeadFileError::WriteStdout(err)) => return Err(err.into()),
}
}
}

#[cfg(not(unix))]
{
let mut stdin = stdin.lock();

match options.mode {
Mode::FirstBytes(n) => print_n_bytes(&mut stdin, n),
Mode::AllButLastBytes(n) => print_but_last_n_bytes(&mut stdin, n),
let result = match options.mode {
Mode::FirstBytes(n) => {
print_n_bytes(&mut stdin, n).map_err(HeadFileError::WriteStdout)
}
Mode::AllButLastBytes(n) => {
print_but_last_n_bytes(&mut stdin, n).map_err(HeadFileError::WriteStdout)
}
Mode::FirstLines(n) => print_n_lines(&mut stdin, n, options.line_ending.into()),
Mode::AllButLastLines(n) => {
print_but_last_n_lines(&mut stdin, n, options.line_ending.into())
.map_err(HeadFileError::WriteStdout)
}
}?;
};

match result {
Ok(_) => {}
Err(HeadFileError::Read(err)) => {
return Err(HeadError::Io {
name: "standard input".into(),
err,
}
.into());
}
Err(HeadFileError::WriteStdout(err)) => return Err(err.into()),
}
}

Ok(())
Expand Down Expand Up @@ -493,7 +566,17 @@ fn uu_head(options: &HeadOptions) -> UResult<()> {
continue;
}
};
head_file(&mut file_handle, options)?;
match head_file(&mut file_handle, options) {
Ok(_) => {}
Err(HeadFileError::Read(err)) => {
show!(HeadError::Io {
name: file.into(),
err
});
continue;
}
Err(HeadFileError::WriteStdout(err)) => return Err(err.into()),
}
Ok(())
};
if let Err(err) = res {
Expand Down
82 changes: 1 addition & 81 deletions src/uu/head/src/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,62 +303,11 @@ pub fn copy_all_but_n_lines<R: Read, W: Write>(
Ok(total_bytes_copied)
}

/// A line-counting counterpart to `std::io::Take`.
///
/// Wraps a reader and stops yielding data once `limit` occurrences of
/// `separator` have passed through. Instances are normally obtained
/// from [`take_lines`]; see that function for details.
pub struct TakeLines<T> {
    inner: T,
    limit: u64,
    separator: u8,
}

impl<T: Read> Read for TakeLines<T> {
    /// Fill `buf` from the wrapped reader, cutting the result off right
    /// after the last line that is still allowed.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // Line budget exhausted: report EOF without touching the inner reader.
        if self.limit == 0 {
            return Ok(0);
        }

        let filled = self.inner.read(buf)?;
        for line_end in memchr_iter(self.separator, &buf[..filled]) {
            self.limit -= 1;
            if self.limit == 0 {
                // Only the bytes up to and including the final separator are
                // reported; the rest of what the inner reader put in `buf`
                // is not returned to the caller.
                return Ok(line_end + 1);
            }
        }
        Ok(filled)
    }
}

/// Wrap `reader` so that at most `limit` lines can be read from it.
///
/// The returned value implements `Read`. Once `limit` lines have been
/// produced, every further read yields EOF (`Ok(0)`).
///
/// `separator` is the byte that terminates a line; pass `b'\n'` for the
/// conventional meaning of "line".
pub fn take_lines<R>(reader: R, limit: u64, separator: u8) -> TakeLines<R> {
    let inner = reader;
    TakeLines {
        inner,
        limit,
        separator,
    }
}

#[cfg(test)]
mod tests {

use std::io::{BufRead, BufReader};

use crate::take::{
TakeAllBuffer, TakeAllLinesBuffer, copy_all_but_n_bytes, copy_all_but_n_lines, take_lines,
TakeAllBuffer, TakeAllLinesBuffer, copy_all_but_n_bytes, copy_all_but_n_lines,
};

#[test]
Expand Down Expand Up @@ -635,33 +584,4 @@ mod tests {
assert_eq!(bytes_copied, 2);
assert_eq!(output_reader.get_ref()[..], input_buffer.as_bytes()[0..2]);
}

#[test]
fn test_zero_lines() {
    // With a limit of zero, no line may come out even though input exists.
    let source = std::io::Cursor::new("a\nb\nc\n");
    let limited = BufReader::new(take_lines(source, 0, b'\n'));
    let mut lines = limited.lines().map(|line| line.unwrap());
    assert!(lines.next().is_none());
}

#[test]
fn test_fewer_lines() {
    // Asking for fewer lines than the input holds yields exactly that many.
    let source = std::io::Cursor::new("a\nb\nc\n");
    let limited = BufReader::new(take_lines(source, 2, b'\n'));
    let lines: Vec<String> = limited.lines().map(|line| line.unwrap()).collect();
    assert_eq!(vec![String::from("a"), String::from("b")], lines);
}

#[test]
fn test_more_lines() {
    // Asking for more lines than the input holds yields the whole input.
    let source = std::io::Cursor::new("a\nb\nc\n");
    let limited = BufReader::new(take_lines(source, 4, b'\n'));
    let lines: Vec<String> = limited.lines().map(|line| line.unwrap()).collect();
    assert_eq!(
        vec![String::from("a"), String::from("b"), String::from("c")],
        lines
    );
}
}
35 changes: 35 additions & 0 deletions tests/by-util/test_head.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,22 @@ fn test_multiple_nonexistent_files() {
.stderr_contains("cannot open 'bogusfile2' for reading: No such file or directory");
}

#[test]
#[cfg(all(target_os = "linux", not(target_env = "musl")))]
#[cfg_attr(wasi_runner, ignore = "WASI sandbox: host paths not visible")]
fn test_multiple_files_read_error_continues_to_next_file() {
    // A read failure on the first operand must be reported on stderr without
    // preventing the second operand from being printed; the overall run
    // still fails.
    let scenario = TestScenario::new(util_name!());
    scenario.fixtures.write("a", "hello\n");

    let expected_stdout = "==> /proc/self/mem <==\n\n==> a <==\nhello\n";
    let expected_stderr = "head: error reading '/proc/self/mem': Input/output error";

    scenario
        .ucmd()
        .args(&["/proc/self/mem", "a"])
        .fails()
        .stdout_is(expected_stdout)
        .stderr_contains(expected_stderr);
}

// there was a bug not caught by previous tests
// where for negative n > 3, the total amount of lines
// was correct, but it would eat from the second line
Expand Down Expand Up @@ -901,6 +917,25 @@ fn test_write_to_dev_full() {
}
}

#[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "netbsd"))]
#[test]
fn test_write_to_dev_full_with_named_file() {
    use std::fs::OpenOptions;

    // With stdout redirected to `/dev/full`, heading a named file must fail
    // with an error attributed to standard output, not to the input file.
    let scenario = TestScenario::new(util_name!());
    scenario.fixtures.write("input", "hello\nworld\n");

    let full_device = OpenOptions::new().write(true).open("/dev/full").unwrap();

    scenario
        .ucmd()
        .arg("input")
        .set_stdout(full_device)
        .fails()
        .stderr_is("head: error writing 'standard output': No space left on device\n");
}

#[test]
#[cfg(target_os = "linux")]
#[cfg_attr(wasi_runner, ignore = "WASI: argv/filenames must be valid UTF-8")]
Expand Down
Loading