[AI assisted] MM-62755: Refactor scanning to map to a util (#30780)

With some neat generics, I was able to refactor
the scanning to a util function. I used it to
refactor 3 places and also removed an unnecessary method.

Claude was quite good here.

https://mattermost.atlassian.net/browse/MM-62755
```release-note
NONE
```
This commit is contained in:
Agniva De Sarker 2025-04-28 19:21:12 +05:30 committed by GitHub
parent cd5523f5fb
commit efde5e2717
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 178 additions and 50 deletions

3
.gitignore vendored
View file

@ -160,4 +160,7 @@ docker-compose.override.yaml
.aider*
.env
**/CLAUDE.local.md
CLAUDE.md

View file

@ -421,8 +421,6 @@ type allChannelMember struct {
ChannelSchemeDefaultAdminRole sql.NullString
}
type allChannelMembers []allChannelMember
func (db allChannelMember) Process() (string, string) {
roles := strings.Fields(db.Roles)
@ -472,17 +470,6 @@ func (db allChannelMember) Process() (string, string) {
return db.ChannelId, strings.Join(roles, " ")
}
func (db allChannelMembers) ToMapStringString() map[string]string {
result := make(map[string]string)
for _, item := range db {
key, value := item.Process()
result[key] = value
}
return result
}
// publicChannel is a subset of the metadata corresponding to public channels only.
type publicChannel struct {
Id string `json:"id"`
@ -2288,8 +2275,7 @@ func (s SqlChannelStore) GetAllChannelMembersForUser(rctx request.CTX, userId st
}
defer deferClose(rows, &err)
var data allChannelMembers
for rows.Next() {
scanner := func(rows *sql.Rows) (string, string, error) {
var cm allChannelMember
err = rows.Scan(
&cm.ChannelId, &cm.Roles, &cm.SchemeGuest, &cm.SchemeUser,
@ -2297,17 +2283,10 @@ func (s SqlChannelStore) GetAllChannelMembersForUser(rctx request.CTX, userId st
&cm.TeamSchemeDefaultAdminRole, &cm.ChannelSchemeDefaultGuestRole,
&cm.ChannelSchemeDefaultUserRole, &cm.ChannelSchemeDefaultAdminRole,
)
if err != nil {
return nil, errors.Wrap(err, "unable to scan columns")
}
data = append(data, cm)
k, v := cm.Process()
return k, v, errors.Wrap(err, "unable to scan columns")
}
if err = rows.Err(); err != nil {
return nil, errors.Wrap(err, "error while iterating over rows")
}
ids := data.ToMapStringString()
return ids, nil
return scanRowsIntoMap(rows, scanner, nil)
}
func (s SqlChannelStore) GetChannelsMemberCount(channelIDs []string) (_ map[string]int64, err error) {
@ -2326,33 +2305,26 @@ func (s SqlChannelStore) GetChannelsMemberCount(channelIDs []string) (_ map[stri
return nil, errors.Wrap(err, "channels_member_count_tosql")
}
rows, err := s.GetReplica().DB.Query(queryString, args...)
rows, err := s.GetReplica().Query(queryString, args...)
if err != nil {
return nil, errors.Wrap(err, "failed to fetch member counts")
}
defer rows.Close()
memberCounts := make(map[string]int64)
// Initialize member counts for channels with zero members
// Initialize default values map
defaults := make(map[string]int64, len(channelIDs))
for _, channelID := range channelIDs {
memberCounts[channelID] = 0
defaults[channelID] = 0
}
for rows.Next() {
scanner := func(rows *sql.Rows) (string, int64, error) {
var channelID string
var count int64
errScan := rows.Scan(&channelID, &count)
if errScan != nil {
return nil, errors.Wrap(err, "failed to scan row")
}
memberCounts[channelID] = count
err := rows.Scan(&channelID, &count)
return channelID, count, errors.Wrap(err, "failed to scan row")
}
if err = rows.Err(); err != nil {
return nil, errors.Wrap(err, "error while iterating rows")
}
return memberCounts, nil
return scanRowsIntoMap(rows, scanner, defaults)
}
func (s SqlChannelStore) InvalidateCacheForChannelMembersNotifyProps(channelId string) {
@ -3050,28 +3022,25 @@ func (s SqlChannelStore) AnalyticsCountAll(teamId string) (map[model.ChannelType
query = query.Where(sq.Eq{"TeamId": teamId})
}
sql, args, err := query.ToSql()
sqlStr, args, err := query.ToSql()
if err != nil {
return nil, errors.Wrap(err, "AnalyticsCountAll_ToSql")
}
rows, err := s.GetReplica().Query(sql, args...)
rows, err := s.GetReplica().Query(sqlStr, args...)
if err != nil {
return nil, errors.Wrap(err, "failed to count Channels by type")
}
defer rows.Close()
counts := make(map[model.ChannelType]int64)
for rows.Next() {
scanner := func(rows *sql.Rows) (model.ChannelType, int64, error) {
var channelType model.ChannelType
var count int64
if err := rows.Scan(&channelType, &count); err != nil {
return nil, errors.Wrap(err, "unable to scan row")
}
counts[channelType] = count
err := rows.Scan(&channelType, &count)
return channelType, count, errors.Wrap(err, "unable to scan row")
}
return counts, nil
return scanRowsIntoMap(rows, scanner, nil)
}
func (s SqlChannelStore) GetMembersForUser(teamID string, userID string) (model.ChannelMembers, error) {

View file

@ -196,3 +196,27 @@ func quoteColumnName(driver string, columnName string) string {
return columnName
}
// scanRowsIntoMap scans SQL rows into a map, using a provided scanner function to extract key-value pairs
func scanRowsIntoMap[K comparable, V any](rows *sql.Rows, scanner func(rows *sql.Rows) (K, V, error), defaults map[K]V) (map[K]V, error) {
results := make(map[K]V, len(defaults))
// Initialize with default values if provided
for k, v := range defaults {
results[k] = v
}
for rows.Next() {
key, value, err := scanner(rows)
if err != nil {
return nil, err
}
results[key] = value
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error while iterating rows: %w", err)
}
return results, nil
}

View file

@ -4,8 +4,11 @@
package sqlstore
import (
"database/sql"
"testing"
"github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/store"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -133,3 +136,132 @@ func TestMySQLJSONArgs(t *testing.T) {
assert.Equal(t, test.argString, argString)
}
}
func TestScanRowsIntoMap(t *testing.T) {
StoreTest(t, func(t *testing.T, rctx request.CTX, ss store.Store) {
sqlStore := ss.(*SqlStore)
t.Run("basic mapping", func(t *testing.T) {
// Create a test table
_, err := sqlStore.GetMaster().Exec(`
CREATE TEMPORARY TABLE IF NOT EXISTS MapTest (
id VARCHAR(50) PRIMARY KEY,
value INT
)
`)
require.NoError(t, err)
// Insert test data
_, err = sqlStore.GetMaster().Exec(`
INSERT INTO MapTest VALUES ('key1', 10), ('key2', 20), ('key3', 30)
`)
require.NoError(t, err)
// Query the data
rows, err := sqlStore.GetMaster().Query(`SELECT id, value FROM MapTest ORDER BY id`)
require.NoError(t, err)
defer rows.Close()
// Create scanner function
scanner := func(rows *sql.Rows) (string, int, error) {
var key string
var value int
return key, value, rows.Scan(&key, &value)
}
// Call the function under test
result, err := scanRowsIntoMap(rows, scanner, nil)
// Assert results
require.NoError(t, err)
require.Len(t, result, 3)
require.Equal(t, 10, result["key1"])
require.Equal(t, 20, result["key2"])
require.Equal(t, 30, result["key3"])
})
t.Run("with default values", func(t *testing.T) {
// Create a test table
_, err := sqlStore.GetMaster().Exec(`
CREATE TEMPORARY TABLE IF NOT EXISTS MapTestDefaults (
id VARCHAR(50) PRIMARY KEY,
value INT
)
`)
require.NoError(t, err)
// Insert test data - only insert one key to test defaults
_, err = sqlStore.GetMaster().Exec(`
INSERT INTO MapTestDefaults VALUES ('key1', 10)
`)
require.NoError(t, err)
// Query the data
rows, err := sqlStore.GetMaster().Query(`SELECT id, value FROM MapTestDefaults`)
require.NoError(t, err)
defer rows.Close()
// Create scanner function
scanner := func(rows *sql.Rows) (string, int, error) {
var key string
var value int
return key, value, rows.Scan(&key, &value)
}
// Define defaults
defaults := map[string]int{
"key1": 100, // Should be overwritten
"key2": 200, // Should remain
"key3": 300, // Should remain
}
// Call the function under test
result, err := scanRowsIntoMap(rows, scanner, defaults)
// Assert results
require.NoError(t, err)
require.Len(t, result, 3)
require.Equal(t, 10, result["key1"]) // Should be from DB, not default
require.Equal(t, 200, result["key2"]) // Should be from defaults
require.Equal(t, 300, result["key3"]) // Should be from defaults
})
t.Run("with empty result set", func(t *testing.T) {
// Create a test table
_, err := sqlStore.GetMaster().Exec(`
CREATE TEMPORARY TABLE IF NOT EXISTS MapTestEmpty (
id VARCHAR(50) PRIMARY KEY,
value INT
)
`)
require.NoError(t, err)
// Query the empty table
rows, err := sqlStore.GetMaster().Query(`SELECT id, value FROM MapTestEmpty`)
require.NoError(t, err)
defer rows.Close()
// Create scanner function
scanner := func(rows *sql.Rows) (string, int, error) {
var key string
var value int
return key, value, rows.Scan(&key, &value)
}
// Define defaults
defaults := map[string]int{
"key1": 100,
"key2": 200,
}
// Call the function under test
result, err := scanRowsIntoMap(rows, scanner, defaults)
// Assert results
require.NoError(t, err)
require.Len(t, result, 2)
require.Equal(t, 100, result["key1"]) // Should be from defaults
require.Equal(t, 200, result["key2"]) // Should be from defaults
})
})
}