Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ func (s *stmt) Close() error {
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { panic("deprecated, unused") }
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { panic("deprecated, unused") }

func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, errRet error) {
if s.closed.Load() {
UsesAfterClose.Add("stmt.ExecContext", 1)
return nil, ErrClosed
Expand All @@ -566,7 +566,10 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
return nil, s.reserr("Stmt.Exec(Bind)", err)
}
if s.conn.logger != nil && !s.conn.readOnly {
s.conn.logger.Statement(s.stmt.ExpandedSQL())
esql := s.stmt.ExpandedSQL()
defer func() {
s.conn.logger.Statement(esql, errRet)
}()
}

if ctx.Value(queryCancelKey{}) != nil {
Expand Down Expand Up @@ -1196,7 +1199,8 @@ type ConnLogger interface {
Begin()

// Statement is called with evaluated SQL when a statement is executed.
Statement(sql string)
// err is the error (if any) resulting from executing the statement.
Statement(sql string, err error)

// Commit is called after a commit statement, with the error resulting
// from the attempted commit.
Expand Down
62 changes: 42 additions & 20 deletions sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1416,24 +1416,37 @@ func TestDisableFunction(t *testing.T) {
}
}

type statements struct {
succeeded []string
failed []string
}

func (s statements) String() string {
return fmt.Sprintf("succeeded\n------------\n%s\n\nfailed\n------------\n%s", strings.Join(s.succeeded, "\n"), strings.Join(s.failed, "\n"))
}

type connLogger struct {
ch chan []string
statements []string
ch chan statements
statements statements
panicOnUse bool
}

func (cl *connLogger) Begin() {
if cl.panicOnUse {
panic("unexpected connLogger.Begin()")
}
cl.statements = nil
cl.statements = statements{}
}

func (cl *connLogger) Statement(s string) {
func (cl *connLogger) Statement(s string, err error) {
if cl.panicOnUse {
panic("unexpected connLogger.Statement: " + s)
}
cl.statements = append(cl.statements, s)
if err == nil {
cl.statements.succeeded = append(cl.statements.succeeded, s)
} else {
cl.statements.failed = append(cl.statements.failed, s)
}
}

func (cl *connLogger) Commit(err error) {
Expand All @@ -1450,7 +1463,7 @@ func (cl *connLogger) Rollback() {
if cl.panicOnUse {
panic("unexpected connLogger.Rollback()")
}
cl.statements = nil
cl.statements = statements{}
}

func TestConnLogger_writable(t *testing.T) {
Expand All @@ -1461,7 +1474,7 @@ func TestConnLogger_writable(t *testing.T) {
}
t.Run(doneStatement, func(t *testing.T) {
ctx := context.Background()
ch := make(chan []string, 1)
ch := make(chan statements, 1)
txl := connLogger{ch: ch}
makeLogger := func() ConnLogger { return &txl }
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
Expand All @@ -1471,7 +1484,7 @@ func TestConnLogger_writable(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil {
if _, err := tx.Exec("CREATE TABLE T (x INTEGER UNIQUE)"); err != nil {
t.Fatal(err)
}
if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil {
Expand All @@ -1480,6 +1493,10 @@ func TestConnLogger_writable(t *testing.T) {
if _, err := tx.Query("SELECT x FROM T"); err != nil {
t.Fatal(err)
}
// the below should fail because T already contains value 1
if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err == nil {
t.Fatal("unique constraint violation should have failed")
}
done := tx.Rollback
if commit {
done = tx.Commit
Expand All @@ -1490,22 +1507,27 @@ func TestConnLogger_writable(t *testing.T) {
if !commit {
select {
case got := <-ch:
t.Errorf("unexpectedly logged statements for rollback:\n%s", strings.Join(got, "\n"))
t.Errorf("unexpectedly logged statements for rollback:\n%s", got)
default:
return
}
}

want := []string{
"BEGIN IMMEDIATE",
"CREATE TABLE T (x INTEGER)",
"INSERT INTO T VALUES (1)",
doneStatement,
want := statements{
succeeded: []string{
"BEGIN IMMEDIATE",
"CREATE TABLE T (x INTEGER UNIQUE)",
"INSERT INTO T VALUES (1)",
doneStatement,
},
failed: []string{
"INSERT INTO T VALUES (1)",
},
}
select {
case got := <-ch:
if !slices.Equal(got, want) {
t.Errorf("unexpected log statements. got:\n%s\n\nwant:\n%s", strings.Join(got, "\n"), strings.Join(want, "\n"))
if !slices.Equal(got.succeeded, want.succeeded) || !slices.Equal(got.failed, want.failed) {
t.Errorf("unexpected log statements. got:\n%s\nwant:\n%s", got, want)
}
default:
t.Fatal("no logged statements after commit")
Expand All @@ -1516,7 +1538,7 @@ func TestConnLogger_writable(t *testing.T) {

func TestConnLogger_commit_error(t *testing.T) {
ctx := context.Background()
ch := make(chan []string, 1)
ch := make(chan statements, 1)
txl := connLogger{ch: ch}
makeLogger := func() ConnLogger { return &txl }
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
Expand Down Expand Up @@ -1544,15 +1566,15 @@ func TestConnLogger_commit_error(t *testing.T) {
}
select {
case got := <-ch:
t.Errorf("unexpectedly logged statements for errored commit:\n%s", strings.Join(got, "\n"))
t.Errorf("unexpectedly logged statements for errored commit:\n%s", got)
default:
return
}
}

func TestConnLogger_read_tx(t *testing.T) {
ctx := context.Background()
ch := make(chan []string, 1)
ch := make(chan statements, 1)
txl := connLogger{ch: ch}
makeLogger := func() ConnLogger { return &txl }
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
Expand All @@ -1573,7 +1595,7 @@ func TestConnLogger_read_tx(t *testing.T) {
}
select {
case got := <-ch:
if len(got) == 0 {
if len(got.succeeded) == 0 {
t.Errorf("expected logged statements for write tx")
}
default:
Expand Down
Loading