diff --git a/.gitignore b/.gitignore index d2feb985d38..78131290bb3 100644 --- a/.gitignore +++ b/.gitignore @@ -160,4 +160,7 @@ docker-compose.override.yaml .aider* .env + +**/CLAUDE.local.md CLAUDE.md + diff --git a/server/channels/store/sqlstore/channel_store.go b/server/channels/store/sqlstore/channel_store.go index 684c0d31a9b..eea663f6e6b 100644 --- a/server/channels/store/sqlstore/channel_store.go +++ b/server/channels/store/sqlstore/channel_store.go @@ -421,8 +421,6 @@ type allChannelMember struct { ChannelSchemeDefaultAdminRole sql.NullString } -type allChannelMembers []allChannelMember - func (db allChannelMember) Process() (string, string) { roles := strings.Fields(db.Roles) @@ -472,17 +470,6 @@ func (db allChannelMember) Process() (string, string) { return db.ChannelId, strings.Join(roles, " ") } -func (db allChannelMembers) ToMapStringString() map[string]string { - result := make(map[string]string) - - for _, item := range db { - key, value := item.Process() - result[key] = value - } - - return result -} - // publicChannel is a subset of the metadata corresponding to public channels only. type publicChannel struct { Id string `json:"id"` @@ -2288,8 +2275,7 @@ func (s SqlChannelStore) GetAllChannelMembersForUser(rctx request.CTX, userId st } defer deferClose(rows, &err) - var data allChannelMembers - for rows.Next() { + scanner := func(rows *sql.Rows) (string, string, error) { var cm allChannelMember err = rows.Scan( &cm.ChannelId, &cm.Roles, &cm.SchemeGuest, &cm.SchemeUser, @@ -2297,17 +2283,10 @@ func (s SqlChannelStore) GetAllChannelMembersForUser(rctx request.CTX, userId st &cm.TeamSchemeDefaultAdminRole, &cm.ChannelSchemeDefaultGuestRole, &cm.ChannelSchemeDefaultUserRole, &cm.ChannelSchemeDefaultAdminRole, ) - if err != nil { - return nil, errors.Wrap(err, "unable to scan columns") - } - data = append(data, cm) + k, v := cm.Process() + return k, v, errors.Wrap(err, "unable to scan columns") } - if err = rows.Err(); err != nil { - return nil, errors.Wrap(err, "error while iterating over rows") - } - ids := data.ToMapStringString() - - return ids, nil + return scanRowsIntoMap(rows, scanner, nil) } func (s SqlChannelStore) GetChannelsMemberCount(channelIDs []string) (_ map[string]int64, err error) { @@ -2326,33 +2305,26 @@ func (s SqlChannelStore) GetChannelsMemberCount(channelIDs []string) (_ map[stri return nil, errors.Wrap(err, "channels_member_count_tosql") } - rows, err := s.GetReplica().DB.Query(queryString, args...) - + rows, err := s.GetReplica().Query(queryString, args...) if err != nil { return nil, errors.Wrap(err, "failed to fetch member counts") } defer rows.Close() - memberCounts := make(map[string]int64) - // Initialize member counts for channels with zero members + // Initialize default values map + defaults := make(map[string]int64, len(channelIDs)) for _, channelID := range channelIDs { - memberCounts[channelID] = 0 + defaults[channelID] = 0 } - for rows.Next() { + + scanner := func(rows *sql.Rows) (string, int64, error) { var channelID string var count int64 - errScan := rows.Scan(&channelID, &count) - if errScan != nil { - return nil, errors.Wrap(err, "failed to scan row") - } - memberCounts[channelID] = count + err := rows.Scan(&channelID, &count) + return channelID, count, errors.Wrap(err, "failed to scan row") } - if err = rows.Err(); err != nil { - return nil, errors.Wrap(err, "error while iterating rows") - } - - return memberCounts, nil + return scanRowsIntoMap(rows, scanner, defaults) } func (s SqlChannelStore) InvalidateCacheForChannelMembersNotifyProps(channelId string) { @@ -3050,28 +3022,25 @@ func (s SqlChannelStore) AnalyticsCountAll(teamId string) (map[model.ChannelType query = query.Where(sq.Eq{"TeamId": teamId}) } - sql, args, err := query.ToSql() + sqlStr, args, err := query.ToSql() if err != nil { return nil, errors.Wrap(err, "AnalyticsCountAll_ToSql") } - rows, err := s.GetReplica().Query(sql, args...) + rows, err := s.GetReplica().Query(sqlStr, args...) if err != nil { return nil, errors.Wrap(err, "failed to count Channels by type") } defer rows.Close() - counts := make(map[model.ChannelType]int64) - for rows.Next() { + scanner := func(rows *sql.Rows) (model.ChannelType, int64, error) { var channelType model.ChannelType var count int64 - if err := rows.Scan(&channelType, &count); err != nil { - return nil, errors.Wrap(err, "unable to scan row") - } - counts[channelType] = count + err := rows.Scan(&channelType, &count) + return channelType, count, errors.Wrap(err, "unable to scan row") } - return counts, nil + return scanRowsIntoMap(rows, scanner, nil) } func (s SqlChannelStore) GetMembersForUser(teamID string, userID string) (model.ChannelMembers, error) { diff --git a/server/channels/store/sqlstore/utils.go b/server/channels/store/sqlstore/utils.go index 9056c6f42e0..22ae4bac4d0 100644 --- a/server/channels/store/sqlstore/utils.go +++ b/server/channels/store/sqlstore/utils.go @@ -196,3 +196,27 @@ func quoteColumnName(driver string, columnName string) string { return columnName } + +// scanRowsIntoMap scans SQL rows into a map, using a provided scanner function to extract key-value pairs +func scanRowsIntoMap[K comparable, V any](rows *sql.Rows, scanner func(rows *sql.Rows) (K, V, error), defaults map[K]V) (map[K]V, error) { + results := make(map[K]V, len(defaults)) + + // Initialize with default values if provided + for k, v := range defaults { + results[k] = v + } + + for rows.Next() { + key, value, err := scanner(rows) + if err != nil { + return nil, err + } + results[key] = value + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error while iterating rows: %w", err) + } + + return results, nil +} diff --git a/server/channels/store/sqlstore/utils_test.go b/server/channels/store/sqlstore/utils_test.go index 6ec6a3b2fe2..1d3da071864 100644 --- a/server/channels/store/sqlstore/utils_test.go +++ b/server/channels/store/sqlstore/utils_test.go @@ -4,8 +4,11 @@ package sqlstore import ( + "database/sql" "testing" + "github.com/mattermost/mattermost/server/public/shared/request" + "github.com/mattermost/mattermost/server/v8/channels/store" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -133,3 +136,132 @@ func TestMySQLJSONArgs(t *testing.T) { assert.Equal(t, test.argString, argString) } } + +func TestScanRowsIntoMap(t *testing.T) { + StoreTest(t, func(t *testing.T, rctx request.CTX, ss store.Store) { + sqlStore := ss.(*SqlStore) + + t.Run("basic mapping", func(t *testing.T) { + // Create a test table + _, err := sqlStore.GetMaster().Exec(` + CREATE TEMPORARY TABLE IF NOT EXISTS MapTest ( + id VARCHAR(50) PRIMARY KEY, + value INT + ) + `) + require.NoError(t, err) + + // Insert test data + _, err = sqlStore.GetMaster().Exec(` + INSERT INTO MapTest VALUES ('key1', 10), ('key2', 20), ('key3', 30) + `) + require.NoError(t, err) + + // Query the data + rows, err := sqlStore.GetMaster().Query(`SELECT id, value FROM MapTest ORDER BY id`) + require.NoError(t, err) + defer rows.Close() + + // Create scanner function + scanner := func(rows *sql.Rows) (string, int, error) { + var key string + var value int + return key, value, rows.Scan(&key, &value) + } + + // Call the function under test + result, err := scanRowsIntoMap(rows, scanner, nil) + + // Assert results + require.NoError(t, err) + require.Len(t, result, 3) + require.Equal(t, 10, result["key1"]) + require.Equal(t, 20, result["key2"]) + require.Equal(t, 30, result["key3"]) + }) + + t.Run("with default values", func(t *testing.T) { + // Create a test table + _, err := sqlStore.GetMaster().Exec(` + CREATE TEMPORARY TABLE IF NOT EXISTS MapTestDefaults ( + id VARCHAR(50) PRIMARY KEY, + value INT + ) + `) + require.NoError(t, err) + + // Insert test data - only insert one key to test defaults + _, err = sqlStore.GetMaster().Exec(` + INSERT INTO MapTestDefaults VALUES ('key1', 10) + `) + require.NoError(t, err) + + // Query the data + rows, err := sqlStore.GetMaster().Query(`SELECT id, value FROM MapTestDefaults`) + require.NoError(t, err) + defer rows.Close() + + // Create scanner function + scanner := func(rows *sql.Rows) (string, int, error) { + var key string + var value int + return key, value, rows.Scan(&key, &value) + } + + // Define defaults + defaults := map[string]int{ + "key1": 100, // Should be overwritten + "key2": 200, // Should remain + "key3": 300, // Should remain + } + + // Call the function under test + result, err := scanRowsIntoMap(rows, scanner, defaults) + + // Assert results + require.NoError(t, err) + require.Len(t, result, 3) + require.Equal(t, 10, result["key1"]) // Should be from DB, not default + require.Equal(t, 200, result["key2"]) // Should be from defaults + require.Equal(t, 300, result["key3"]) // Should be from defaults + }) + + t.Run("with empty result set", func(t *testing.T) { + // Create a test table + _, err := sqlStore.GetMaster().Exec(` + CREATE TEMPORARY TABLE IF NOT EXISTS MapTestEmpty ( + id VARCHAR(50) PRIMARY KEY, + value INT + ) + `) + require.NoError(t, err) + + // Query the empty table + rows, err := sqlStore.GetMaster().Query(`SELECT id, value FROM MapTestEmpty`) + require.NoError(t, err) + defer rows.Close() + + // Create scanner function + scanner := func(rows *sql.Rows) (string, int, error) { + var key string + var value int + return key, value, rows.Scan(&key, &value) + } + + // Define defaults + defaults := map[string]int{ + "key1": 100, + "key2": 200, + } + + // Call the function under test + result, err := scanRowsIntoMap(rows, scanner, defaults) + + // Assert results + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, 100, result["key1"]) // Should be from defaults + require.Equal(t, 200, result["key2"]) // Should be from defaults + }) + }) +}