diff --git a/.idea/misc.xml b/.idea/misc.xml index f518f84..e659c40 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -54,7 +54,7 @@ - + diff --git a/metrik/merge.py b/metrik/merge.py index d351d04..6266ccb 100644 --- a/metrik/merge.py +++ b/metrik/merge.py @@ -9,9 +9,9 @@ def open_connection(host, port): return MongoClient(host=host, port=port) -def merge(con1, con2, database_name='metrik'): - database1 = con1[database_name] - database2 = con2[database_name] +def merge(con1, con2, db1, db2): + database1 = con1[db1] + database2 = con2[db2] collections = database1.collection_names(include_system_collections=False) for collection_name in collections: collection1 = database1[collection_name] @@ -37,14 +37,16 @@ def main(): help='The port number of the `left` database') parser.add_argument('-o', '--port-2', default=27017, dest='port2', type=int, help='The port number of the `right` database') - parser.add_argument('-d', '--database', default='metrik', - help='The database to merge from one host to the other') + parser.add_argument('-d', '--database-1', default='metrik', dest='db1', + help='The database on the `left` host we are merging from') + parser.add_argument('-s', '--database-2', default='metrik', dest='db2', + help='The database on the `right` host we are merging into') parser.add_argument('-v', '--version', action='version', version=__version__) args = parser.parse_args() con1 = open_connection(args.host1, args.port1) con2 = open_connection(args.host2, args.port2) - merge(con1, con2, args.database) + merge(con1, con2, args.db1, args.db2) con1.close() con2.close() diff --git a/test/test_merge.py b/test/test_merge.py new file mode 100644 index 0000000..83643a1 --- /dev/null +++ b/test/test_merge.py @@ -0,0 +1,35 @@ +import random, string + +from metrik.merge import merge, open_connection +from metrik.conf import get_config +from test.mongo_test import MongoTest + + +class MergeTest(MongoTest): + db2_name = 'metrik_test_2' + collection_name = 'merge_test' + + def setUp(self): + super(MergeTest, self).setUp() + self.client2 = self.client + self.db2 = self.client2[self.db2_name] + + def tearDown(self): + super(MergeTest, self).tearDown() + self.client2.drop_database(self.db2_name) + + def test_left_right_merge(self): + item_string = ''.join(random.choice(string.lowercase) for i in range(10)) + item = {'string': item_string} + item_id = self.db[self.collection_name].save(item) + + merge(self.client, self.client2, + self.db.name, self.db2.name) + + assert len(list(self.db[self.collection_name].find())) == 0 + assert len(list(self.db2[self.collection_name].find())) == 1 + + item_retrieved = self.db2[self.collection_name].find_one({'_id': item_id}) + assert item_retrieved is not None + assert item_retrieved['string'] == item_string +