177 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			Ruby
		
	
	
	
	
	
			
		
		
	
	
			177 lines
		
	
	
		
			5.4 KiB
		
	
	
	
		
			Ruby
		
	
	
	
	
	
# frozen_string_literal: true

module Mastodon::Snowflake
  # Matches a column default of the form `timestamp_id('<prefix>'...` and
  # captures <prefix>, from which the sequence name `<prefix>_id_seq` is built.
  DEFAULT_REGEX = /timestamp_id\('(?<seq_prefix>\w+)'/

  # Callback object that gives backdated records (records whose created_at was
  # set in the past) an ID derived from that created_at instead of "now".
  class Callbacks
    # Wraps the actual insert, which runs when this method yields. The name
    # suggests it is registered as an ActiveRecord `around_create` callback —
    # NOTE(review): confirm at the registration site.
    #
    # The insert proceeds untouched (no ID assigned here) when any of these
    # holds:
    # - created_at is nil;
    # - created_at is not in the past;
    # - created_at equals updated_at;
    # - override_timestamps is truthy.
    # Presumably the ID then comes from the column default; TODO confirm.
    #
    # Otherwise the ID is computed from created_at with
    # Mastodon::Snowflake.id_at. If the insert raises RecordNotUnique, the ID
    # is bumped by a random 1..100 and the insert retried. The error is
    # re-raised once 101 retries have failed.
    #
    # @param record [Object] must respond to created_at, updated_at,
    #   override_timestamps and id / id=
    # @raise [ActiveRecord::RecordNotUnique] when the retries are exhausted
    def self.around_create(record)
      now = Time.now.utc

      if record.created_at.nil? || record.created_at >= now || record.created_at == record.updated_at || record.override_timestamps
        yield
      else
        record.id = Mastodon::Snowflake.id_at(record.created_at)
        tries     = 0

        begin
          yield
        rescue ActiveRecord::RecordNotUnique
          raise if tries > 100

          tries     += 1
          # Always move by at least 1: a zero bump (possible with rand(100))
          # would retry the very ID that just collided and waste an attempt.
          record.id += rand(1..100)

          retry
        end
      end
    end
  end

  class << self
    # Our ID will be composed of the following:
    # 6 bytes (48 bits) of millisecond-level timestamp
    # 2 bytes (16 bits) of sequence data
    #
    # The 'sequence data' is intended to be unique within a
    # given millisecond, yet obscure the 'serial number' of
    # this row.
    #
    # To do this, we hash the following data:
    # * Table name (if provided, skipped if not)
    # * Secret salt (should not be guessable)
    # * Timestamp (again, millisecond-level granularity)
    #
    # We then take the first two bytes of that value, and add
    # the lowest two bytes of the table ID sequence number
    # (`table_name`_id_seq). This means that even if we insert
    # two rows at the same millisecond, they will have
    # distinct 'sequence data' portions.
    #
    # If this happens, and an attacker can see both such IDs,
    # they can determine which of the two entries was inserted
    # first, but not the total number of entries in the table
    # (even mod 2**16).
    #
    # The table name is included in the hash to ensure that
    # different tables derive separate sequence bases so rows
    # inserted in the same millisecond in different tables do
    # not reveal the table ID sequence number for one another.
    #
    # The secret salt is included in the hash to ensure that
    # external users cannot derive the sequence base given the
    # timestamp and table name, which would allow them to
    # compute the table ID sequence number.
    #
    # Installs the `timestamp_id(table_name)` PL/pgSQL function, with a
    # freshly generated salt. Does nothing if a function named
    # `timestamp_id` already exists in pg_proc, so an existing definition
    # (and its salt) is never replaced by this method.
    def define_timestamp_id
      return if already_defined?

      connection.execute(sanitized_timestamp_id_sql)
    end

    # For every table whose `id` column default matches DEFAULT_REGEX,
    # creates the `<prefix>_id_seq` sequence that timestamp_id() draws from.
    # A sequence that already exists is silently kept (duplicate_table is
    # swallowed inside the DO block).
    def ensure_id_sequences_exist
      # Find tables using timestamp IDs.
      connection.tables.each do |table|
        # We're only concerned with "id" columns.
        next unless (id_col = connection.columns(table).find { |col| col.name == 'id' })

        # And only those that are using timestamp_id.
        next unless (data = DEFAULT_REGEX.match(id_col.default_function))

        seq_name = "#{data[:seq_prefix]}_id_seq"

        # If we were on Postgres 9.5+, we could do CREATE SEQUENCE IF
        # NOT EXISTS, but we can't depend on that. Instead, catch the
        # possible exception and ignore it.
        # Note that seq_name isn't a column name, but it's a
        # relation, like a column, and follows the same quoting rules
        # in Postgres.
        connection.execute(<<~SQL)
          DO $$
            BEGIN
              CREATE SEQUENCE #{connection.quote_column_name(seq_name)};
            EXCEPTION WHEN duplicate_table THEN
              -- Do nothing, we have the sequence already.
            END
          $$ LANGUAGE plpgsql;
        SQL
      end
    end

    # Builds a snowflake ID for +timestamp+ on the Ruby side.
    #
    # Only the whole seconds of +timestamp+ are used (`to_i`). With
    # +with_random+ (the default), a random 0..999 millisecond offset and a
    # random 16-bit tail are added. Without it, the result is the smallest ID
    # whose time part falls in that second (millisecond 0, tail 0).
    #
    # @param timestamp [#to_i] e.g. a Time
    # @param with_random [Boolean]
    # @return [Integer]
    def id_at(timestamp, with_random: true)
      id  = timestamp.to_i * 1000
      id += rand(1000) if with_random
      id  = id << 16
      id += rand(2**16) if with_random
      id
    end

    # Recovers the time encoded in +id+, as a UTC Time. The millisecond part
    # is dropped by integer division, so the result has whole-second
    # precision.
    #
    # @param id [Integer]
    # @return [Time]
    def to_time(id)
      Time.at((id >> 16) / 1000).utc
    end

    private

    # Returns the first column of the first row of a
    # `SELECT EXISTS(... pg_proc WHERE proname = 'timestamp_id')` query.
    # NOTE(review): relies on the adapter type-casting the boolean result to
    # true/false — confirm.
    def already_defined?
      connection.execute(<<~SQL.squish).values.first.first
        SELECT EXISTS(
          SELECT * FROM pg_proc WHERE proname = 'timestamp_id'
        );
      SQL
    end

    # The CREATE FUNCTION statement with :random_string substituted as a
    # quoted SQL literal.
    def sanitized_timestamp_id_sql
      ActiveRecord::Base.sanitize_sql_array(timestamp_id_sql_array)
    end

    # SQL template plus bind values. A new 32-hex-character salt
    # (SecureRandom.hex(16)) is generated on every call.
    def timestamp_id_sql_array
      [timestamp_id_sql_string, { random_string: SecureRandom.hex(16) }]
    end

    # PL/pgSQL source for timestamp_id(table_name), implementing the layout
    # described above define_timestamp_id. :random_string is a named bind.
    def timestamp_id_sql_string
      <<~SQL
        CREATE OR REPLACE FUNCTION timestamp_id(table_name text)
        RETURNS bigint AS
        $$
          DECLARE
            time_part bigint;
            sequence_base bigint;
            tail bigint;
          BEGIN
            time_part := (
              -- Get the time in milliseconds
              ((date_part('epoch', now()) * 1000))::bigint
              -- And shift it over two bytes
              << 16);

            sequence_base := (
              'x' ||
              -- Take the first two bytes (four hex characters)
              substr(
                -- Of the MD5 hash of the data we documented
                md5(table_name || :random_string || time_part::text),
                1, 4
              )
            -- And turn it into a bigint
            )::bit(16)::bigint;

            -- Finally, add our sequence number to our base, and chop
            -- it to the last two bytes
            tail := (
              (sequence_base + nextval(table_name || '_id_seq'))
              & 65535);

            -- Return the time part and the sequence part. OR appears
            -- faster here than addition, but they're equivalent:
            -- time_part has no trailing two bytes, and tail is only
            -- the last two bytes.
            RETURN time_part | tail;
          END
        $$ LANGUAGE plpgsql VOLATILE;
      SQL
    end

    # The shared ActiveRecord connection used for all statements above.
    def connection
      ActiveRecord::Base.connection
    end
  end
end
 |