mirror of
https://github.com/mattermost/mattermost.git
synced 2026-02-03 20:40:00 -05:00
[AI assisted] MM-62755: Refactor scanning to map to a util (#30780)
With some neat generics, I was able to refactor the scanning to a util function. I used it to refactor 3 places and also removed an unnecessary method. Claude was quite good here. https://mattermost.atlassian.net/browse/MM-62755 ```release-note NONE ```
This commit is contained in:
parent
cd5523f5fb
commit
efde5e2717
4 changed files with 178 additions and 50 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -160,4 +160,7 @@ docker-compose.override.yaml
|
|||
.aider*
|
||||
.env
|
||||
|
||||
|
||||
**/CLAUDE.local.md
|
||||
CLAUDE.md
|
||||
|
||||
|
|
|
|||
|
|
@ -421,8 +421,6 @@ type allChannelMember struct {
|
|||
ChannelSchemeDefaultAdminRole sql.NullString
|
||||
}
|
||||
|
||||
type allChannelMembers []allChannelMember
|
||||
|
||||
func (db allChannelMember) Process() (string, string) {
|
||||
roles := strings.Fields(db.Roles)
|
||||
|
||||
|
|
@ -472,17 +470,6 @@ func (db allChannelMember) Process() (string, string) {
|
|||
return db.ChannelId, strings.Join(roles, " ")
|
||||
}
|
||||
|
||||
func (db allChannelMembers) ToMapStringString() map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
for _, item := range db {
|
||||
key, value := item.Process()
|
||||
result[key] = value
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// publicChannel is a subset of the metadata corresponding to public channels only.
|
||||
type publicChannel struct {
|
||||
Id string `json:"id"`
|
||||
|
|
@ -2288,8 +2275,7 @@ func (s SqlChannelStore) GetAllChannelMembersForUser(rctx request.CTX, userId st
|
|||
}
|
||||
defer deferClose(rows, &err)
|
||||
|
||||
var data allChannelMembers
|
||||
for rows.Next() {
|
||||
scanner := func(rows *sql.Rows) (string, string, error) {
|
||||
var cm allChannelMember
|
||||
err = rows.Scan(
|
||||
&cm.ChannelId, &cm.Roles, &cm.SchemeGuest, &cm.SchemeUser,
|
||||
|
|
@ -2297,17 +2283,10 @@ func (s SqlChannelStore) GetAllChannelMembersForUser(rctx request.CTX, userId st
|
|||
&cm.TeamSchemeDefaultAdminRole, &cm.ChannelSchemeDefaultGuestRole,
|
||||
&cm.ChannelSchemeDefaultUserRole, &cm.ChannelSchemeDefaultAdminRole,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unable to scan columns")
|
||||
}
|
||||
data = append(data, cm)
|
||||
k, v := cm.Process()
|
||||
return k, v, errors.Wrap(err, "unable to scan columns")
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, errors.Wrap(err, "error while iterating over rows")
|
||||
}
|
||||
ids := data.ToMapStringString()
|
||||
|
||||
return ids, nil
|
||||
return scanRowsIntoMap(rows, scanner, nil)
|
||||
}
|
||||
|
||||
func (s SqlChannelStore) GetChannelsMemberCount(channelIDs []string) (_ map[string]int64, err error) {
|
||||
|
|
@ -2326,33 +2305,26 @@ func (s SqlChannelStore) GetChannelsMemberCount(channelIDs []string) (_ map[stri
|
|||
return nil, errors.Wrap(err, "channels_member_count_tosql")
|
||||
}
|
||||
|
||||
rows, err := s.GetReplica().DB.Query(queryString, args...)
|
||||
|
||||
rows, err := s.GetReplica().Query(queryString, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to fetch member counts")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
memberCounts := make(map[string]int64)
|
||||
// Initialize member counts for channels with zero members
|
||||
// Initialize default values map
|
||||
defaults := make(map[string]int64, len(channelIDs))
|
||||
for _, channelID := range channelIDs {
|
||||
memberCounts[channelID] = 0
|
||||
defaults[channelID] = 0
|
||||
}
|
||||
for rows.Next() {
|
||||
|
||||
scanner := func(rows *sql.Rows) (string, int64, error) {
|
||||
var channelID string
|
||||
var count int64
|
||||
errScan := rows.Scan(&channelID, &count)
|
||||
if errScan != nil {
|
||||
return nil, errors.Wrap(err, "failed to scan row")
|
||||
}
|
||||
memberCounts[channelID] = count
|
||||
err := rows.Scan(&channelID, &count)
|
||||
return channelID, count, errors.Wrap(err, "failed to scan row")
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, errors.Wrap(err, "error while iterating rows")
|
||||
}
|
||||
|
||||
return memberCounts, nil
|
||||
return scanRowsIntoMap(rows, scanner, defaults)
|
||||
}
|
||||
|
||||
func (s SqlChannelStore) InvalidateCacheForChannelMembersNotifyProps(channelId string) {
|
||||
|
|
@ -3050,28 +3022,25 @@ func (s SqlChannelStore) AnalyticsCountAll(teamId string) (map[model.ChannelType
|
|||
query = query.Where(sq.Eq{"TeamId": teamId})
|
||||
}
|
||||
|
||||
sql, args, err := query.ToSql()
|
||||
sqlStr, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "AnalyticsCountAll_ToSql")
|
||||
}
|
||||
|
||||
rows, err := s.GetReplica().Query(sql, args...)
|
||||
rows, err := s.GetReplica().Query(sqlStr, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to count Channels by type")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
counts := make(map[model.ChannelType]int64)
|
||||
for rows.Next() {
|
||||
scanner := func(rows *sql.Rows) (model.ChannelType, int64, error) {
|
||||
var channelType model.ChannelType
|
||||
var count int64
|
||||
if err := rows.Scan(&channelType, &count); err != nil {
|
||||
return nil, errors.Wrap(err, "unable to scan row")
|
||||
}
|
||||
counts[channelType] = count
|
||||
err := rows.Scan(&channelType, &count)
|
||||
return channelType, count, errors.Wrap(err, "unable to scan row")
|
||||
}
|
||||
|
||||
return counts, nil
|
||||
return scanRowsIntoMap(rows, scanner, nil)
|
||||
}
|
||||
|
||||
func (s SqlChannelStore) GetMembersForUser(teamID string, userID string) (model.ChannelMembers, error) {
|
||||
|
|
|
|||
|
|
@ -196,3 +196,27 @@ func quoteColumnName(driver string, columnName string) string {
|
|||
|
||||
return columnName
|
||||
}
|
||||
|
||||
// scanRowsIntoMap scans SQL rows into a map, using a provided scanner function to extract key-value pairs
|
||||
func scanRowsIntoMap[K comparable, V any](rows *sql.Rows, scanner func(rows *sql.Rows) (K, V, error), defaults map[K]V) (map[K]V, error) {
|
||||
results := make(map[K]V, len(defaults))
|
||||
|
||||
// Initialize with default values if provided
|
||||
for k, v := range defaults {
|
||||
results[k] = v
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
key, value, err := scanner(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results[key] = value
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error while iterating rows: %w", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,8 +4,11 @@
|
|||
package sqlstore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/mattermost/mattermost/server/public/shared/request"
|
||||
"github.com/mattermost/mattermost/server/v8/channels/store"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
|
@ -133,3 +136,132 @@ func TestMySQLJSONArgs(t *testing.T) {
|
|||
assert.Equal(t, test.argString, argString)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanRowsIntoMap(t *testing.T) {
|
||||
StoreTest(t, func(t *testing.T, rctx request.CTX, ss store.Store) {
|
||||
sqlStore := ss.(*SqlStore)
|
||||
|
||||
t.Run("basic mapping", func(t *testing.T) {
|
||||
// Create a test table
|
||||
_, err := sqlStore.GetMaster().Exec(`
|
||||
CREATE TEMPORARY TABLE IF NOT EXISTS MapTest (
|
||||
id VARCHAR(50) PRIMARY KEY,
|
||||
value INT
|
||||
)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert test data
|
||||
_, err = sqlStore.GetMaster().Exec(`
|
||||
INSERT INTO MapTest VALUES ('key1', 10), ('key2', 20), ('key3', 30)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Query the data
|
||||
rows, err := sqlStore.GetMaster().Query(`SELECT id, value FROM MapTest ORDER BY id`)
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
// Create scanner function
|
||||
scanner := func(rows *sql.Rows) (string, int, error) {
|
||||
var key string
|
||||
var value int
|
||||
return key, value, rows.Scan(&key, &value)
|
||||
}
|
||||
|
||||
// Call the function under test
|
||||
result, err := scanRowsIntoMap(rows, scanner, nil)
|
||||
|
||||
// Assert results
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 3)
|
||||
require.Equal(t, 10, result["key1"])
|
||||
require.Equal(t, 20, result["key2"])
|
||||
require.Equal(t, 30, result["key3"])
|
||||
})
|
||||
|
||||
t.Run("with default values", func(t *testing.T) {
|
||||
// Create a test table
|
||||
_, err := sqlStore.GetMaster().Exec(`
|
||||
CREATE TEMPORARY TABLE IF NOT EXISTS MapTestDefaults (
|
||||
id VARCHAR(50) PRIMARY KEY,
|
||||
value INT
|
||||
)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Insert test data - only insert one key to test defaults
|
||||
_, err = sqlStore.GetMaster().Exec(`
|
||||
INSERT INTO MapTestDefaults VALUES ('key1', 10)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Query the data
|
||||
rows, err := sqlStore.GetMaster().Query(`SELECT id, value FROM MapTestDefaults`)
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
// Create scanner function
|
||||
scanner := func(rows *sql.Rows) (string, int, error) {
|
||||
var key string
|
||||
var value int
|
||||
return key, value, rows.Scan(&key, &value)
|
||||
}
|
||||
|
||||
// Define defaults
|
||||
defaults := map[string]int{
|
||||
"key1": 100, // Should be overwritten
|
||||
"key2": 200, // Should remain
|
||||
"key3": 300, // Should remain
|
||||
}
|
||||
|
||||
// Call the function under test
|
||||
result, err := scanRowsIntoMap(rows, scanner, defaults)
|
||||
|
||||
// Assert results
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 3)
|
||||
require.Equal(t, 10, result["key1"]) // Should be from DB, not default
|
||||
require.Equal(t, 200, result["key2"]) // Should be from defaults
|
||||
require.Equal(t, 300, result["key3"]) // Should be from defaults
|
||||
})
|
||||
|
||||
t.Run("with empty result set", func(t *testing.T) {
|
||||
// Create a test table
|
||||
_, err := sqlStore.GetMaster().Exec(`
|
||||
CREATE TEMPORARY TABLE IF NOT EXISTS MapTestEmpty (
|
||||
id VARCHAR(50) PRIMARY KEY,
|
||||
value INT
|
||||
)
|
||||
`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Query the empty table
|
||||
rows, err := sqlStore.GetMaster().Query(`SELECT id, value FROM MapTestEmpty`)
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
// Create scanner function
|
||||
scanner := func(rows *sql.Rows) (string, int, error) {
|
||||
var key string
|
||||
var value int
|
||||
return key, value, rows.Scan(&key, &value)
|
||||
}
|
||||
|
||||
// Define defaults
|
||||
defaults := map[string]int{
|
||||
"key1": 100,
|
||||
"key2": 200,
|
||||
}
|
||||
|
||||
// Call the function under test
|
||||
result, err := scanRowsIntoMap(rows, scanner, defaults)
|
||||
|
||||
// Assert results
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 2)
|
||||
require.Equal(t, 100, result["key1"]) // Should be from defaults
|
||||
require.Equal(t, 200, result["key2"]) // Should be from defaults
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue