import unittest

from mongoengine import *
from mongoengine.connection import get_db
from mongoengine.context_managers import (switch_db, switch_collection,
                                          no_sub_classes, no_dereference,
                                          query_counter)
from mongoengine.pymongo_support import count_documents


class ContextManagersTest(unittest.TestCase):
    """Integration tests for mongoengine's context managers.

    Every test that needs a connection calls ``connect('mongoenginetest')``
    itself, so each test can run in isolation regardless of ordering.
    """

    def test_switch_db_context_manager(self):
        """switch_db points a Document at another connection and restores it on exit."""
        connect('mongoenginetest')
        register_connection('testdb-1', 'mongoenginetest2')

        class Group(Document):
            name = StringField()

        Group.drop_collection()         # drops in default db

        with switch_db(Group, 'testdb-1') as Group:
            Group.drop_collection()     # drops in testdb-1 (clears leftovers of earlier runs)

        Group(name="hello - default").save()
        self.assertEqual(1, Group.objects.count())

        with switch_db(Group, 'testdb-1') as Group:

            self.assertEqual(0, Group.objects.count())

            Group(name="hello").save()

            self.assertEqual(1, Group.objects.count())

            Group.drop_collection()
            self.assertEqual(0, Group.objects.count())

        # Back on the default connection: the original document is still there.
        self.assertEqual(1, Group.objects.count())

    def test_switch_collection_context_manager(self):
        """switch_collection points a Document at another collection and restores it on exit."""
        connect('mongoenginetest')
        register_connection(alias='testdb-1', db='mongoenginetest2')

        class Group(Document):
            name = StringField()

        Group.drop_collection()         # drops in default

        with switch_collection(Group, 'group1') as Group:
            Group.drop_collection()     # drops in group1

        Group(name="hello - group").save()
        self.assertEqual(1, Group.objects.count())

        with switch_collection(Group, 'group1') as Group:

            self.assertEqual(0, Group.objects.count())

            Group(name="hello - group1").save()

            self.assertEqual(1, Group.objects.count())

            Group.drop_collection()
            self.assertEqual(0, Group.objects.count())

        # Back on the default collection: the original document is still there.
        self.assertEqual(1, Group.objects.count())

    def test_no_dereference_context_manager_object_id(self):
        """Ensure references stored with dbref=False aren't dereferenced inside no_dereference.
        """
        connect('mongoenginetest')

        class User(Document):
            name = StringField()

        class Group(Document):
            ref = ReferenceField(User, dbref=False)
            generic = GenericReferenceField()
            members = ListField(ReferenceField(User, dbref=False))

        User.drop_collection()
        Group.drop_collection()

        for i in range(1, 51):
            User(name='user %s' % i).save()

        user = User.objects.first()
        Group(ref=user, members=User.objects, generic=user).save()

        with no_dereference(Group) as NoDeRefGroup:
            self.assertTrue(Group._fields['members']._auto_dereference)
            self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)

        with no_dereference(Group) as Group:
            group = Group.objects.first()
            for m in group.members:
                self.assertNotIsInstance(m, User)
            self.assertNotIsInstance(group.ref, User)
            self.assertNotIsInstance(group.generic, User)

        # Outside the context manager the same instance dereferences again.
        for m in group.members:
            self.assertIsInstance(m, User)
        self.assertIsInstance(group.ref, User)
        self.assertIsInstance(group.generic, User)

    def test_no_dereference_context_manager_dbref(self):
        """Ensure references stored with dbref=True aren't dereferenced inside no_dereference.
        """
        connect('mongoenginetest')

        class User(Document):
            name = StringField()

        class Group(Document):
            ref = ReferenceField(User, dbref=True)
            generic = GenericReferenceField()
            members = ListField(ReferenceField(User, dbref=True))

        User.drop_collection()
        Group.drop_collection()

        for i in range(1, 51):
            User(name='user %s' % i).save()

        user = User.objects.first()
        Group(ref=user, members=User.objects, generic=user).save()

        with no_dereference(Group) as NoDeRefGroup:
            self.assertTrue(Group._fields['members']._auto_dereference)
            self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)

        with no_dereference(Group) as Group:
            group = Group.objects.first()
            # Per-item assertions (rather than assertTrue(all(...))) report
            # the offending member on failure.
            for m in group.members:
                self.assertNotIsInstance(m, User)
            self.assertNotIsInstance(group.ref, User)
            self.assertNotIsInstance(group.generic, User)

        # Outside the context manager the same instance dereferences again.
        for m in group.members:
            self.assertIsInstance(m, User)
        self.assertIsInstance(group.ref, User)
        self.assertIsInstance(group.generic, User)

    def test_no_sub_classes(self):
        """no_sub_classes limits queries to the exact class and restores on exit."""
        connect('mongoenginetest')

        class A(Document):
            x = IntField()
            meta = {'allow_inheritance': True}

        class B(A):
            z = IntField()

        class C(B):
            zz = IntField()

        A.drop_collection()

        A(x=10).save()
        A(x=15).save()
        B(x=20).save()
        B(x=30).save()
        C(x=40).save()

        self.assertEqual(A.objects.count(), 5)
        self.assertEqual(B.objects.count(), 3)
        self.assertEqual(C.objects.count(), 1)

        with no_sub_classes(A):
            self.assertEqual(A.objects.count(), 2)

            for obj in A.objects:
                self.assertEqual(obj.__class__, A)

        with no_sub_classes(B):
            self.assertEqual(B.objects.count(), 2)

            for obj in B.objects:
                self.assertEqual(obj.__class__, B)

        with no_sub_classes(C):
            self.assertEqual(C.objects.count(), 1)

            for obj in C.objects:
                self.assertEqual(obj.__class__, C)

        # Confirm context manager exit correctly
        self.assertEqual(A.objects.count(), 5)
        self.assertEqual(B.objects.count(), 3)
        self.assertEqual(C.objects.count(), 1)

    def test_no_sub_classes_modification_to_document_class_are_temporary(self):
        """no_sub_classes only changes ``_subclasses`` for the duration of the block."""
        class A(Document):
            x = IntField()
            meta = {'allow_inheritance': True}

        class B(A):
            z = IntField()

        self.assertEqual(A._subclasses, ('A', 'A.B'))
        with no_sub_classes(A):
            self.assertEqual(A._subclasses, ('A',))
        self.assertEqual(A._subclasses, ('A', 'A.B'))

        self.assertEqual(B._subclasses, ('A.B',))
        with no_sub_classes(B):
            self.assertEqual(B._subclasses, ('A.B',))
        self.assertEqual(B._subclasses, ('A.B',))

    def test_no_subclass_context_manager_does_not_swallow_exception(self):
        """Exceptions raised inside no_sub_classes propagate to the caller."""
        class User(Document):
            name = StringField()

        with self.assertRaises(TypeError):
            with no_sub_classes(User):
                raise TypeError()

    def test_query_counter_does_not_swallow_exception(self):
        """Exceptions raised inside query_counter propagate to the caller."""
        connect('mongoenginetest')

        with self.assertRaises(TypeError):
            with query_counter():
                raise TypeError()

    def test_query_counter_temporarily_modifies_profiling_level(self):
        """query_counter sets profiling level 2 and restores the previous level on exit."""
        connect('mongoenginetest')
        db = get_db()

        initial_profiling_level = db.profiling_level()

        try:
            new_level = 1
            db.set_profiling_level(new_level)
            self.assertEqual(db.profiling_level(), new_level)
            with query_counter():
                self.assertEqual(db.profiling_level(), 2)
            self.assertEqual(db.profiling_level(), new_level)
        finally:
            # Ensures it gets reset no matter the outcome of the test, so the
            # server is not left at the temporary level after a passing run.
            db.set_profiling_level(initial_profiling_level)

    def test_query_counter(self):
        """query_counter counts inserts, finds and counts, and compares like an int."""
        connect('mongoenginetest')
        db = get_db()

        collection = db.query_counter
        collection.drop()

        def issue_1_count_query():
            count_documents(collection, {})

        def issue_1_insert_query():
            collection.insert_one({'test': 'garbage'})

        def issue_1_find_query():
            collection.find_one()

        counter = 0
        with query_counter() as q:
            self.assertEqual(q, counter)
            self.assertEqual(q, counter)    # Ensures previous count query did not get counted

            for _ in range(10):
                issue_1_insert_query()
                counter += 1
            self.assertEqual(q, counter)

            for _ in range(4):
                issue_1_find_query()
                counter += 1
            self.assertEqual(q, counter)

            for _ in range(3):
                issue_1_count_query()
                counter += 1
            self.assertEqual(q, counter)

            self.assertEqual(int(q), counter)       # test __int__
            self.assertEqual(repr(q), str(int(q)))  # test __repr__
            self.assertGreater(q, -1)               # test __gt__
            self.assertGreaterEqual(q, int(q))      # test __ge__
            self.assertNotEqual(q, -1)              # test __ne__
            self.assertLess(q, 1000)                # test __lt__
            self.assertLessEqual(q, int(q))         # test __le__

    def test_query_counter_counts_getmore_queries(self):
        """Fetching past the first cursor batch counts the extra getmore query."""
        connect('mongoenginetest')
        db = get_db()

        collection = db.query_counter
        collection.drop()

        many_docs = [{'test': 'garbage %s' % i} for i in range(150)]
        collection.insert_many(many_docs)   # first batch of documents contains 101 documents

        with query_counter() as q:
            self.assertEqual(q, 0)
            list(collection.find())
            self.assertEqual(q, 2)  # 1st select + 1 getmore

    def test_query_counter_ignores_particular_queries(self):
        """killcursors and db.system.indexes queries are not counted."""
        connect('mongoenginetest')
        db = get_db()

        collection = db.query_counter
        collection.drop()   # avoid accumulating documents across test runs
        collection.insert_many([{'test': 'garbage %s' % i} for i in range(10)])

        with query_counter() as q:
            self.assertEqual(q, 0)
            cursor = collection.find()
            self.assertEqual(q, 0)      # cursor wasn't opened yet
            _ = next(cursor)            # opens the cursor and fires the find query
            self.assertEqual(q, 1)

            cursor.close()              # issues a `killcursors` query that is ignored by the context
            self.assertEqual(q, 1)
            _ = db.system.indexes.find_one()    # queries on db.system.indexes are ignored as well
            self.assertEqual(q, 1)


if __name__ == '__main__':
    # Script entry point: run this module's TestCases with the default
    # unittest verbosity (1, passed explicitly here).
    unittest.main(verbosity=1)
