s/in the url/in the xml/ On Sep 21, 2010, at 3:09 AM, Soren Hansen wrote: > Soren Hansen has proposed merging lp:~soren/nova/ec2-security-groups into lp:nova. > > Requested reviews: > Nova Core (nova-core) > > > This patch adds support for EC2 security groups using libvirt's nwfilter mechanism, which in turn uses iptables and ebtables on the individual compute nodes. > This has a number of benefits: > * Inter-VM network traffic can take the fastest route through the network without our having to worry about getting it through a central firewall. > * Not relying on a central firewall also removes a potential single point of failure (SPOF). > * The filtering load is distributed, offering great scalability. > > Caveats: > * It only works with libvirt and only with libvirt drivers that support nwfilter (qemu (and thus kvm) and uml, at the moment) > -- > https://code.launchpad.net/~soren/nova/ec2-security-groups/+merge/36119 > Your team Nova Core is requested to review the proposed merge of lp:~soren/nova/ec2-security-groups into lp:nova. 
> === modified file 'nova/auth/manager.py' > --- nova/auth/manager.py 2010-09-21 05:08:31 +0000 > +++ nova/auth/manager.py 2010-09-21 10:09:06 +0000 > @@ -490,6 +490,12 @@ > except: > drv.delete_project(project.id) > raise > + > + values = { 'name' : 'default', > + 'description' : 'default', > + 'user_id' : User.safe_id(manager_user), > + 'project_id' : project.id } > + db.security_group_create({}, values) > return project > > def modify_project(self, project, manager_user=None, description=None): > @@ -565,6 +571,16 @@ > except: > logging.exception('Could not destroy network for %s', > project) > + try: > + project_id = Project.safe_id(project) > + groups = db.security_group_get_by_project(context={}, > + project_id=project_id) > + for group in groups: > + db.security_group_destroy({}, group['id']) > + except: > + logging.exception('Could not destroy security groups for %s', > + project) > + > with self.driver() as drv: > drv.delete_project(Project.safe_id(project)) > > > === modified file 'nova/compute/manager.py' > --- nova/compute/manager.py 2010-09-13 09:15:02 +0000 > +++ nova/compute/manager.py 2010-09-21 10:09:06 +0000 > @@ -64,6 +64,11 @@ > > @defer.inlineCallbacks > @exception.wrap_exception > + def refresh_security_group(self, context, security_group_id, **_kwargs): > + yield self.driver.refresh_security_group(security_group_id) > + > + @defer.inlineCallbacks > + @exception.wrap_exception > def run_instance(self, context, instance_id, **_kwargs): > """Launch a new instance with specified options.""" > instance_ref = self.db.instance_get(context, instance_id) > > === modified file 'nova/db/api.py' > --- nova/db/api.py 2010-09-21 02:17:36 +0000 > +++ nova/db/api.py 2010-09-21 10:09:06 +0000 > @@ -296,6 +296,7 @@ > return IMPL.instance_update(context, instance_id, values) > > > +<<<<<<< TREE > ################### > > > @@ -324,6 +325,13 @@ > return IMPL.key_pair_get_all_by_user(context, user_id) > > > +======= > +def instance_add_security_group(context, 
instance_id, security_group_id): > + """Associate the given security group with the given instance""" > + return IMPL.instance_add_security_group(context, instance_id, security_group_id) > + > + > +>>>>>>> MERGE-SOURCE > #################### > > > @@ -539,3 +547,63 @@ > > """ > return IMPL.volume_update(context, volume_id, values) > + > + > +#################### > + > + > +def security_group_get_all(context): > + """Get all security groups""" > + return IMPL.security_group_get_all(context) > + > + > +def security_group_get(context, security_group_id): > + """Get security group by its internal id""" > + return IMPL.security_group_get(context, security_group_id) > + > + > +def security_group_get_by_name(context, project_id, group_name): > + """Returns a security group with the specified name from a project""" > + return IMPL.security_group_get_by_name(context, project_id, group_name) > + > + > +def security_group_get_by_project(context, project_id): > + """Get all security groups belonging to a project""" > + return IMPL.security_group_get_by_project(context, project_id) > + > + > +def security_group_get_by_instance(context, instance_id): > + """Get security groups to which the instance is assigned""" > + return IMPL.security_group_get_by_instance(context, instance_id) > + > + > +def securitygroup_exists(context, project_id, group_name): > + """Indicates if a group name exists in a project""" > + return IMPL.security_group_exists(context, project_id, group_name) > + > + > +def security_group_create(context, values): > + """Create a new security group""" > + return IMPL.security_group_create(context, values) > + > + > +def security_group_destroy(context, security_group_id): > + """Deletes a security group""" > + return IMPL.security_group_destroy(context, security_group_id) > + > + > +#################### > + > + > +def security_group_rule_create(context, values): > + """Create a new security group""" > + return IMPL.security_group_rule_create(context, values) > + > + 
> +def security_group_rule_get_by_security_group(context, security_group_id): > + """Get all rules for a a given security group""" > + return IMPL.security_group_rule_get_by_security_group(context, security_group_id) > + > +def security_group_rule_destroy(context, security_group_rule_id): > + """Deletes a security group rule""" > + return IMPL.security_group_rule_destroy(context, security_group_rule_id) > > === modified file 'nova/db/sqlalchemy/api.py' > --- nova/db/sqlalchemy/api.py 2010-09-21 02:17:36 +0000 > +++ nova/db/sqlalchemy/api.py 2010-09-21 10:09:06 +0000 > @@ -25,8 +25,12 @@ > from nova.db.sqlalchemy import models > from nova.db.sqlalchemy.session import get_session > from sqlalchemy import or_ > +<<<<<<< TREE > from sqlalchemy.orm import joinedload_all > from sqlalchemy.sql import func > +======= > +from sqlalchemy.orm import eagerload > +>>>>>>> MERGE-SOURCE > > FLAGS = flags.FLAGS > > @@ -378,6 +382,7 @@ > instance_ref.delete(session=session) > > > +<<<<<<< TREE > def instance_get(context, instance_id): > return models.Instance.find(instance_id, deleted=_deleted(context)) > > @@ -394,6 +399,22 @@ > session = get_session() > return session.query(models.Instance > ).options(joinedload_all('fixed_ip.floating_ips') > +======= > +def instance_get(_context, instance_id): > + session = get_session() > + return session.query(models.Instance > + ).options(eagerload('security_groups') > + ).get(instance_id) > + > + > +def instance_get_all(_context): > + return models.Instance.all() > + > + > +def instance_get_by_project(_context, project_id): > + session = get_session() > + return session.query(models.Instance > +>>>>>>> MERGE-SOURCE > ).filter_by(project_id=project_id > ).filter_by(deleted=_deleted(context) > ).all() > @@ -459,6 +480,17 @@ > instance_ref.save(session=session) > > > +def instance_add_security_group(context, instance_id, security_group_id): > + """Associate the given security group with the given instance""" > + session = get_session() > + with 
session.begin(): > + instance_ref = models.Instance.find(instance_id, session=session) > + security_group_ref = models.SecurityGroup.find(security_group_id, > + session=session) > + instance_ref.security_groups += [security_group_ref] > + instance_ref.save(session=session) > + > + > ################### > > > @@ -825,3 +857,106 @@ > for (key, value) in values.iteritems(): > volume_ref[key] = value > volume_ref.save(session=session) > + > + > +################### > + > + > +def security_group_get_all(_context): > + session = get_session() > + return session.query(models.SecurityGroup > + ).options(eagerload('rules') > + ).filter_by(deleted=False > + ).all() > + > + > +def security_group_get(_context, security_group_id): > + session = get_session() > + with session.begin(): > + return session.query(models.SecurityGroup > + ).options(eagerload('rules') > + ).get(security_group_id) > + > + > +def security_group_get_by_name(context, project_id, group_name): > + session = get_session() > + group_ref = session.query(models.SecurityGroup > + ).options(eagerload('rules') > + ).options(eagerload('instances') > + ).filter_by(project_id=project_id > + ).filter_by(name=group_name > + ).filter_by(deleted=False > + ).first() > + if not group_ref: > + raise exception.NotFound( > + 'No security group named %s for project: %s' \ > + % (group_name, project_id)) > + > + return group_ref > + > + > +def security_group_get_by_project(_context, project_id): > + session = get_session() > + return session.query(models.SecurityGroup > + ).options(eagerload('rules') > + ).filter_by(project_id=project_id > + ).filter_by(deleted=False > + ).all() > + > + > +def security_group_get_by_instance(_context, instance_id): > + session = get_session() > + with session.begin(): > + return session.query(models.Instance > + ).get(instance_id > + ).security_groups \ > + .filter_by(deleted=False > + ).all() > + > + > +def security_group_exists(_context, project_id, group_name): > + try: > + group = 
security_group_get_by_name(_context, project_id, group_name) > + return group != None > + except exception.NotFound: > + return False > + > + > +def security_group_create(_context, values): > + security_group_ref = models.SecurityGroup() > + # FIXME(devcamcar): Unless I do this, rules fails with lazy load exception > + # once save() is called. This will get cleaned up in next orm pass. > + security_group_ref.rules > + for (key, value) in values.iteritems(): > + security_group_ref[key] = value > + security_group_ref.save() > + return security_group_ref > + > + > +def security_group_destroy(_context, security_group_id): > + session = get_session() > + with session.begin(): > + # TODO(vish): do we have to use sql here? > + session.execute('update security_group set deleted=1 where id=:id', > + {'id': security_group_id}) > + session.execute('update security_group_rules set deleted=1 ' > + 'where group_id=:id', > + {'id': security_group_id}) > + > + > +################### > + > + > +def security_group_rule_create(_context, values): > + security_group_rule_ref = models.SecurityGroupIngressRule() > + for (key, value) in values.iteritems(): > + security_group_rule_ref[key] = value > + security_group_rule_ref.save() > + return security_group_rule_ref > + > +def security_group_rule_destroy(_context, security_group_rule_id): > + session = get_session() > + with session.begin(): > + security_group_rule = session.query(models.SecurityGroupIngressRule > + ).get(security_group_rule_id) > + security_group_rule.delete(session=session) > > === modified file 'nova/db/sqlalchemy/models.py' > --- nova/db/sqlalchemy/models.py 2010-09-21 02:17:36 +0000 > +++ nova/db/sqlalchemy/models.py 2010-09-21 10:09:06 +0000 > @@ -24,8 +24,14 @@ > import datetime > > # TODO(vish): clean up these imports > +<<<<<<< TREE > from sqlalchemy.orm import relationship, backref, exc, object_mapper > from sqlalchemy import Column, Integer, String > +======= > +from sqlalchemy.orm import relationship, backref, 
validates, exc > +from sqlalchemy.sql import func > +from sqlalchemy import Column, Integer, String, Table > +>>>>>>> MERGE-SOURCE > from sqlalchemy import ForeignKey, DateTime, Boolean, Text > from sqlalchemy.ext.declarative import declarative_base > > @@ -224,7 +230,6 @@ > launch_index = Column(Integer) > key_name = Column(String(255)) > key_data = Column(Text) > - security_group = Column(String(255)) > > state = Column(Integer) > state_description = Column(String(255)) > @@ -323,6 +328,7 @@ > uselist=False)) > > > +<<<<<<< TREE > class KeyPair(BASE, NovaBase): > """Represents a public key pair for ssh""" > __tablename__ = 'key_pairs' > @@ -359,6 +365,59 @@ > raise new_exc.__class__, new_exc, sys.exc_info()[2] > > > +======= > +security_group_instance_association = Table('security_group_instance_association', > + BASE.metadata, > + Column('security_group_id', Integer, > + ForeignKey('security_group.id')), > + Column('instance_id', Integer, > + ForeignKey('instances.id'))) > + > +class SecurityGroup(BASE, NovaBase): > + """Represents a security group""" > + __tablename__ = 'security_group' > + id = Column(Integer, primary_key=True) > + > + name = Column(String(255)) > + description = Column(String(255)) > + > + user_id = Column(String(255)) > + project_id = Column(String(255)) > + > + instances = relationship(Instance, > + secondary=security_group_instance_association, > + backref='security_groups') > + > + @property > + def user(self): > + return auth.manager.AuthManager().get_user(self.user_id) > + > + @property > + def project(self): > + return auth.manager.AuthManager().get_project(self.project_id) > + > + > +class SecurityGroupIngressRule(BASE, NovaBase): > + """Represents a rule in a security group""" > + __tablename__ = 'security_group_rules' > + id = Column(Integer, primary_key=True) > + > + parent_group_id = Column(Integer, ForeignKey('security_group.id')) > + parent_group = relationship("SecurityGroup", backref="rules", > + foreign_keys=parent_group_id, 
> + primaryjoin=parent_group_id==SecurityGroup.id) > + > + protocol = Column(String(5)) # "tcp", "udp", or "icmp" > + from_port = Column(Integer) > + to_port = Column(Integer) > + cidr = Column(String(255)) > + > + # Note: This is not the parent SecurityGroup. It's SecurityGroup we're > + # granting access for. > + group_id = Column(Integer, ForeignKey('security_group.id')) > + > + > +>>>>>>> MERGE-SOURCE > class Network(BASE, NovaBase): > """Represents a network""" > __tablename__ = 'networks' > @@ -462,8 +521,9 @@ > def register_models(): > """Register Models and create metadata""" > from sqlalchemy import create_engine > - models = (Service, Instance, Volume, ExportDevice, > - FixedIp, FloatingIp, Network, NetworkIndex) # , Image, Host) > + models = (Service, Instance, Volume, ExportDevice, FixedIp, FloatingIp, > + Network, NetworkIndex, SecurityGroup, SecurityGroupIngressRule) > + # , Image, Host > engine = create_engine(FLAGS.sql_connection, echo=False) > for model in models: > model.metadata.create_all(engine) > > === modified file 'nova/db/sqlalchemy/session.py' > --- nova/db/sqlalchemy/session.py 2010-09-08 02:48:12 +0000 > +++ nova/db/sqlalchemy/session.py 2010-09-21 10:09:06 +0000 > @@ -36,7 +36,8 @@ > if not _MAKER: > if not _ENGINE: > _ENGINE = create_engine(FLAGS.sql_connection, echo=False) > - _MAKER = sessionmaker(bind=_ENGINE, > - autocommit=autocommit, > - expire_on_commit=expire_on_commit) > - return _MAKER() > + _MAKER = (sessionmaker(bind=_ENGINE, > + autocommit=autocommit, > + expire_on_commit=expire_on_commit)) > + session = _MAKER() > + return session > > === modified file 'nova/endpoint/cloud.py' > --- nova/endpoint/cloud.py 2010-09-21 05:08:31 +0000 > +++ nova/endpoint/cloud.py 2010-09-21 10:09:06 +0000 > @@ -118,6 +118,14 @@ > result[key] = [line] > return result > > + def _trigger_refresh_security_group(self, security_group): > + nodes = set([instance.host for instance in security_group.instances]) > + for node in nodes: > + 
rpc.call('%s.%s' % (FLAGS.compute_topic, node), > + { "method": "refresh_security_group", > + "args": { "context": None, > + "security_group_id": security_group.id}}) > + > def get_metadata(self, address): > instance_ref = db.fixed_ip_get_instance(None, address) > if instance_ref is None: > @@ -252,18 +260,178 @@ > return True > > @rbac.allow('all') > - def describe_security_groups(self, context, group_names, **kwargs): > - groups = {'securityGroupSet': []} > - > - # Stubbed for now to unblock other things. > - return groups > - > - @rbac.allow('netadmin') > - def create_security_group(self, context, group_name, **kwargs): > - return True > + def describe_security_groups(self, context, group_name=None, **kwargs): > + if context.user.is_admin(): > + groups = db.security_group_get_all(context) > + else: > + groups = db.security_group_get_by_project(context, > + context.project.id) > + groups = [self._format_security_group(context, g) for g in groups] > + if not group_name is None: > + groups = [g for g in groups if g.name in group_name] > + > + return {'securityGroupInfo': groups } > + > + def _format_security_group(self, context, group): > + g = {} > + g['groupDescription'] = group.description > + g['groupName'] = group.name > + g['ownerId'] = context.user.id > + g['ipPermissions'] = [] > + for rule in group.rules: > + r = {} > + r['ipProtocol'] = rule.protocol > + r['fromPort'] = rule.from_port > + r['toPort'] = rule.to_port > + r['groups'] = [] > + r['ipRanges'] = [] > + if rule.group_id: > + source_group = db.security_group_get(context, rule.group_id) > + r['groups'] += [{'groupName': source_group.name, > + 'userId': source_group.user_id}] > + else: > + r['ipRanges'] += [{'cidrIp': rule.cidr}] > + g['ipPermissions'] += [r] > + return g > + > + > + @rbac.allow('netadmin') > + def revoke_security_group_ingress(self, context, group_name, > + to_port=None, from_port=None, > + ip_protocol=None, cidr_ip=None, > + user_id=None, > + source_security_group_name=None, > + 
source_security_group_owner_id=None): > + security_group = db.security_group_get_by_name(context, > + context.project.id, > + group_name) > + > + criteria = {} > + > + if source_security_group_name: > + source_project_id = self._get_source_project_id(context, > + source_security_group_owner_id) > + > + source_security_group = \ > + db.security_group_get_by_name(context, > + source_project_id, > + source_security_group_name) > + > + criteria['group_id'] = source_security_group > + elif cidr_ip: > + criteria['cidr'] = cidr_ip > + else: > + return { 'return': False } > + > + if ip_protocol and from_port and to_port: > + criteria['protocol'] = ip_protocol > + criteria['from_port'] = from_port > + criteria['to_port'] = to_port > + else: > + # If cidr based filtering, protocol and ports are mandatory > + if 'cidr' in criteria: > + return { 'return': False } > + > + for rule in security_group.rules: > + for (k,v) in criteria.iteritems(): > + if getattr(rule, k, False) != v: > + break > + # If we make it here, we have a match > + db.security_group_rule_destroy(context, rule.id) > + > + self._trigger_refresh_security_group(security_group) > + > + return True > + > + # TODO(soren): Lots and lots of input validation. We're accepting > + # strings here (such as ipProtocol), which are put into > + # filter rules verbatim. > + # TODO(soren): Dupe detection. Adding the same rule twice actually > + # adds the same rule twice to the rule set, which is > + # pointless. > + # TODO(soren): This has only been tested with Boto as the client. > + # Unfortunately, it seems Boto is using an old API > + # for these operations, so support for newer API versions > + # is sketchy. > + # TODO(soren): De-duplicate the turning method arguments into dict stuff. > + # revoke_security_group_ingress uses the exact same logic. 
> + @rbac.allow('netadmin') > + def authorize_security_group_ingress(self, context, group_name, > + to_port=None, from_port=None, > + ip_protocol=None, cidr_ip=None, > + source_security_group_name=None, > + source_security_group_owner_id=None): > + security_group = db.security_group_get_by_name(context, > + context.project.id, > + group_name) > + values = { 'parent_group_id' : security_group.id } > + > + if source_security_group_name: > + source_project_id = self._get_source_project_id(context, > + source_security_group_owner_id) > + > + source_security_group = \ > + db.security_group_get_by_name(context, > + source_project_id, > + source_security_group_name) > + values['group_id'] = source_security_group.id > + elif cidr_ip: > + values['cidr'] = cidr_ip > + else: > + return { 'return': False } > + > + if ip_protocol and from_port and to_port: > + values['protocol'] = ip_protocol > + values['from_port'] = from_port > + values['to_port'] = to_port > + else: > + # If cidr based filtering, protocol and ports are mandatory > + if 'cidr' in values: > + return None > + > + security_group_rule = db.security_group_rule_create(context, values) > + > + self._trigger_refresh_security_group(security_group) > + > + return True > + > + def _get_source_project_id(self, context, source_security_group_owner_id): > + if source_security_group_owner_id: > + # Parse user:project for source group. > + source_parts = source_security_group_owner_id.split(':') > + > + # If no project name specified, assume it's same as user name. > + # Since we're looking up by project name, the user name is not > + # used here. It's only read for EC2 API compatibility. 
> + if len(source_parts) == 2: > + source_project_id = source_parts[1] > + else: > + source_project_id = source_parts[0] > + else: > + source_project_id = context.project.id > + > + return source_project_id > + > + @rbac.allow('netadmin') > + def create_security_group(self, context, group_name, group_description): > + if db.securitygroup_exists(context, context.project.id, group_name): > + raise exception.ApiError('group %s already exists' % group_name) > + > + group = {'user_id' : context.user.id, > + 'project_id': context.project.id, > + 'name': group_name, > + 'description': group_description} > + group_ref = db.security_group_create(context, group) > + > + return {'securityGroupSet': [self._format_security_group(context, > + group_ref)]} > > @rbac.allow('netadmin') > def delete_security_group(self, context, group_name, **kwargs): > + security_group = db.security_group_get_by_name(context, > + context.project.id, > + group_name) > + db.security_group_destroy(context, security_group.id) > return True > > @rbac.allow('projectmanager', 'sysadmin') > @@ -601,8 +769,16 @@ > kwargs['key_name']) > key_data = key_pair_ref['public_key'] > > - # TODO: Get the real security group of launch in here > - security_group = "default" > + security_group_arg = kwargs.get('security_group', ["default"]) > + if not type(security_group_arg) is list: > + security_group_arg = [security_group_arg] > + > + security_groups = [] > + for security_group_name in security_group_arg: > + group = db.security_group_get_by_project(context, > + context.project.id, > + security_group_name) > + security_groups.append(group['id']) > > reservation_id = utils.generate_uid('r') > base_options = {} > @@ -616,6 +792,7 @@ > base_options['user_id'] = context.user.id > base_options['project_id'] = context.project.id > base_options['user_data'] = kwargs.get('user_data', '') > +<<<<<<< TREE > base_options['security_group'] = security_group > base_options['instance_type'] = instance_type > > @@ -627,6 +804,16 @@ 
> for num in range(num_instances): > instance_ref = db.instance_create(context, base_options) > inst_id = instance_ref['id'] > +======= > + base_options['instance_type'] = kwargs.get('instance_type', 'm1.small') > + > + for num in range(int(kwargs['max_count'])): > + inst_id = db.instance_create(context, base_options) > +>>>>>>> MERGE-SOURCE > + > + for security_group_id in security_groups: > + db.instance_add_security_group(context, inst_id, > + security_group_id) > > inst = {} > inst['mac_address'] = utils.generate_mac() > > === modified file 'nova/network/manager.py' > --- nova/network/manager.py 2010-09-11 15:24:19 +0000 > +++ nova/network/manager.py 2010-09-21 10:09:06 +0000 > @@ -201,7 +201,6 @@ > # in the datastore? > net = {} > net['injected'] = True > - net['network_str'] = FLAGS.flat_network_network > net['netmask'] = FLAGS.flat_network_netmask > net['bridge'] = FLAGS.flat_network_bridge > net['gateway'] = FLAGS.flat_network_gateway > > === modified file 'nova/process.py' > --- nova/process.py 2010-09-11 15:16:16 +0000 > +++ nova/process.py 2010-09-21 10:09:06 +0000 > @@ -113,7 +113,7 @@ > if self.started_deferred: > self.started_deferred.callback(self) > if self.process_input: > - self.transport.write(self.process_input) > + self.transport.write(str(self.process_input)) > self.transport.closeStdin() > > def get_process_output(executable, args=None, env=None, path=None, > > === modified file 'nova/tests/api_unittest.py' > --- nova/tests/api_unittest.py 2010-09-11 05:13:36 +0000 > +++ nova/tests/api_unittest.py 2010-09-21 10:09:06 +0000 > @@ -185,6 +185,9 @@ > self.host = '127.0.0.1' > > self.app = api.APIServerApplication({'Cloud': self.cloud}) > + > + def expect_http(self, host=None, is_secure=False): > + """Returns a new EC2 connection""" > self.ec2 = boto.connect_ec2( > aws_access_key_id='fake', > aws_secret_access_key='fake', > @@ -194,9 +197,6 @@ > path='/services/Cloud') > > self.mox.StubOutWithMock(self.ec2, 'new_http_connection') > - > - def 
expect_http(self, host=None, is_secure=False): > - """Returns a new EC2 connection""" > http = FakeHttplibConnection( > self.app, '%s:%d' % (self.host, FLAGS.cc_port), False) > # pylint: disable-msg=E1103 > @@ -232,3 +232,173 @@ > self.assertEquals(len(results), 1) > self.manager.delete_project(project) > self.manager.delete_user(user) > + > + def test_get_all_security_groups(self): > + """Test that we can retrieve security groups""" > + self.expect_http() > + self.mox.ReplayAll() > + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) > + project = self.manager.create_project('fake', 'fake', 'fake') > + > + rv = self.ec2.get_all_security_groups() > + > + self.assertEquals(len(rv), 1) > + self.assertEquals(rv[0].name, 'default') > + > + self.manager.delete_project(project) > + self.manager.delete_user(user) > + > + def test_create_delete_security_group(self): > + """Test that we can create a security group""" > + self.expect_http() > + self.mox.ReplayAll() > + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) > + project = self.manager.create_project('fake', 'fake', 'fake') > + > + security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ > + for x in range(random.randint(4, 8))) > + > + self.ec2.create_security_group(security_group_name, 'test group') > + > + self.expect_http() > + self.mox.ReplayAll() > + > + rv = self.ec2.get_all_security_groups() > + self.assertEquals(len(rv), 2) > + self.assertTrue(security_group_name in [group.name for group in rv]) > + > + self.expect_http() > + self.mox.ReplayAll() > + > + self.ec2.delete_security_group(security_group_name) > + > + self.manager.delete_project(project) > + self.manager.delete_user(user) > + > + def test_authorize_revoke_security_group_cidr(self): > + """ > + Test that we can add and remove CIDR based rules > + to a security group > + """ > + self.expect_http() > + self.mox.ReplayAll() > + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) > + 
project = self.manager.create_project('fake', 'fake', 'fake') > + > + security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ > + for x in range(random.randint(4, 8))) > + > + group = self.ec2.create_security_group(security_group_name, 'test group') > + > + self.expect_http() > + self.mox.ReplayAll() > + group.connection = self.ec2 > + > + group.authorize('tcp', 80, 81, '0.0.0.0/0') > + > + self.expect_http() > + self.mox.ReplayAll() > + > + rv = self.ec2.get_all_security_groups() > + # I don't bother checkng that we actually find it here, > + # because the create/delete unit test further up should > + # be good enough for that. > + for group in rv: > + if group.name == security_group_name: > + self.assertEquals(len(group.rules), 1) > + self.assertEquals(int(group.rules[0].from_port), 80) > + self.assertEquals(int(group.rules[0].to_port), 81) > + self.assertEquals(len(group.rules[0].grants), 1) > + self.assertEquals(str(group.rules[0].grants[0]), '0.0.0.0/0') > + > + self.expect_http() > + self.mox.ReplayAll() > + group.connection = self.ec2 > + > + group.revoke('tcp', 80, 81, '0.0.0.0/0') > + > + self.expect_http() > + self.mox.ReplayAll() > + > + self.ec2.delete_security_group(security_group_name) > + > + self.expect_http() > + self.mox.ReplayAll() > + group.connection = self.ec2 > + > + rv = self.ec2.get_all_security_groups() > + > + self.assertEqual(len(rv), 1) > + self.assertEqual(rv[0].name, 'default') > + > + self.manager.delete_project(project) > + self.manager.delete_user(user) > + > + return > + > + def test_authorize_revoke_security_group_foreign_group(self): > + """ > + Test that we can grant and revoke another security group access > + to a security group > + """ > + self.expect_http() > + self.mox.ReplayAll() > + user = self.manager.create_user('fake', 'fake', 'fake', admin=True) > + project = self.manager.create_project('fake', 'fake', 'fake') > + > + security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ > + for x 
in range(random.randint(4, 8))) > + other_security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \ > + for x in range(random.randint(4, 8))) > + > + group = self.ec2.create_security_group(security_group_name, 'test group') > + > + self.expect_http() > + self.mox.ReplayAll() > + > + other_group = self.ec2.create_security_group(other_security_group_name, > + 'some other group') > + > + self.expect_http() > + self.mox.ReplayAll() > + group.connection = self.ec2 > + > + group.authorize(src_group=other_group) > + > + self.expect_http() > + self.mox.ReplayAll() > + > + rv = self.ec2.get_all_security_groups() > + > + # I don't bother checkng that we actually find it here, > + # because the create/delete unit test further up should > + # be good enough for that. > + for group in rv: > + if group.name == security_group_name: > + self.assertEquals(len(group.rules), 1) > + self.assertEquals(len(group.rules[0].grants), 1) > + self.assertEquals(str(group.rules[0].grants[0]), > + '%s-%s' % (other_security_group_name, 'fake')) > + > + > + self.expect_http() > + self.mox.ReplayAll() > + > + rv = self.ec2.get_all_security_groups() > + > + for group in rv: > + if group.name == security_group_name: > + self.expect_http() > + self.mox.ReplayAll() > + group.connection = self.ec2 > + group.revoke(src_group=other_group) > + > + self.expect_http() > + self.mox.ReplayAll() > + > + self.ec2.delete_security_group(security_group_name) > + > + self.manager.delete_project(project) > + self.manager.delete_user(user) > + > + return > > === modified file 'nova/tests/virt_unittest.py' > --- nova/tests/virt_unittest.py 2010-08-13 21:46:44 +0000 > +++ nova/tests/virt_unittest.py 2010-09-21 10:09:06 +0000 > @@ -14,23 +14,31 @@ > # License for the specific language governing permissions and limitations > # under the License. 
> > +from xml.dom.minidom import parseString > + > +from nova import db > from nova import flags > from nova import test > +from nova.endpoint import cloud > from nova.virt import libvirt_conn > > FLAGS = flags.FLAGS > > > class LibvirtConnTestCase(test.TrialTestCase): > - def test_get_uri_and_template(self): > + def bitrot_test_get_uri_and_template(self): > class MockDataModel(object): > + def __getitem__(self, name): > + return self.datamodel[name] > + > def __init__(self): > self.datamodel = { 'name' : 'i-cafebabe', > 'memory_kb' : '1024000', > 'basepath' : '/some/path', > 'bridge_name' : 'br100', > 'mac_address' : '02:12:34:46:56:67', > - 'vcpus' : 2 } > + 'vcpus' : 2, > + 'project_id' : None } > > type_uri_map = { 'qemu' : ('qemu:///system', > [lambda s: '' in s, > @@ -53,7 +61,7 @@ > self.assertEquals(uri, expected_uri) > > for i, check in enumerate(checks): > - xml = conn.toXml(MockDataModel()) > + xml = conn.to_xml(MockDataModel()) > self.assertTrue(check(xml), '%s failed check %d' % (xml, i)) > > # Deliberately not just assigning this string to FLAGS.libvirt_uri and > @@ -67,3 +75,118 @@ > uri, template = conn.get_uri_and_template() > self.assertEquals(uri, testuri) > > + > +class NWFilterTestCase(test.TrialTestCase): > + def setUp(self): > + super(NWFilterTestCase, self).setUp() > + > + class Mock(object): > + pass > + > + self.context = Mock() > + self.context.user = Mock() > + self.context.user.id = 'fake' > + self.context.user.is_superuser = lambda:True > + self.context.project = Mock() > + self.context.project.id = 'fake' > + > + self.fake_libvirt_connection = Mock() > + > + self.fw = libvirt_conn.NWFilterFirewall(self.fake_libvirt_connection) > + > + def test_cidr_rule_nwfilter_xml(self): > + cloud_controller = cloud.CloudController() > + cloud_controller.create_security_group(self.context, > + 'testgroup', > + 'test group description') > + cloud_controller.authorize_security_group_ingress(self.context, > + 'testgroup', > + from_port='80', > + 
to_port='81', > + ip_protocol='tcp', > + cidr_ip='0.0.0.0/0') > + > + > + security_group = db.security_group_get_by_name({}, 'fake', 'testgroup') > + > + xml = self.fw.security_group_to_nwfilter_xml(security_group.id) > + > + dom = parseString(xml) > + self.assertEqual(dom.firstChild.tagName, 'filter') > + > + rules = dom.getElementsByTagName('rule') > + self.assertEqual(len(rules), 1) > + > + # It's supposed to allow inbound traffic. > + self.assertEqual(rules[0].getAttribute('action'), 'accept') > + self.assertEqual(rules[0].getAttribute('direction'), 'in') > + > + # Must be lower priority than the base filter (which blocks everything) > + self.assertTrue(int(rules[0].getAttribute('priority')) < 1000) > + > + ip_conditions = rules[0].getElementsByTagName('tcp') > + self.assertEqual(len(ip_conditions), 1) > + self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'), '0.0.0.0/0') > + self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80') > + self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81') > + > + > + self.teardown_security_group() > + > + def teardown_security_group(self): > + cloud_controller = cloud.CloudController() > + cloud_controller.delete_security_group(self.context, 'testgroup') > + > + > + def setup_and_return_security_group(self): > + cloud_controller = cloud.CloudController() > + cloud_controller.create_security_group(self.context, > + 'testgroup', > + 'test group description') > + cloud_controller.authorize_security_group_ingress(self.context, > + 'testgroup', > + from_port='80', > + to_port='81', > + ip_protocol='tcp', > + cidr_ip='0.0.0.0/0') > + > + return db.security_group_get_by_name({}, 'fake', 'testgroup') > + > + def test_creates_base_rule_first(self): > + self.defined_filters = [] > + self.fake_libvirt_connection.listNWFilters = lambda:self.defined_filters > + self.base_filter_defined = False > + self.i = 0 > + def _filterDefineXMLMock(xml): > + dom = parseString(xml) > + name = 
dom.firstChild.getAttribute('name') > + if self.i == 0: > + self.assertEqual(dom.firstChild.getAttribute('name'), > + 'nova-base-filter') > + elif self.i == 1: > + self.assertTrue(name.startswith('nova-secgroup-'), > + 'unexpected name: %s' % name) > + elif self.i == 2: > + self.assertTrue(name.startswith('nova-instance-'), > + 'unexpected name: %s' % name) > + > + self.defined_filters.append(name) > + self.i += 1 > + return True > + > + def _ensure_all_called(_): > + self.assertEqual(self.i, 3) > + > + self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock > + > + inst_id = db.instance_create({}, { 'user_id' : 'fake', 'project_id' : 'fake' }) > + security_group = self.setup_and_return_security_group() > + > + db.instance_add_security_group({}, inst_id, security_group.id) > + instance = db.instance_get({}, inst_id) > + > + d = self.fw.setup_nwfilters_for_instance(instance) > + d.addCallback(_ensure_all_called) > + d.addCallback(lambda _:self.teardown_security_group()) > + > + return d > > === modified file 'nova/virt/interfaces.template' > --- nova/virt/interfaces.template 2010-08-13 21:45:26 +0000 > +++ nova/virt/interfaces.template 2010-09-21 10:09:06 +0000 > @@ -10,7 +10,6 @@ > iface eth0 inet static > address %(address)s > netmask %(netmask)s > - network %(network)s > broadcast %(broadcast)s > gateway %(gateway)s > dns-nameservers %(dns)s > > === modified file 'nova/virt/libvirt.qemu.xml.template' > --- nova/virt/libvirt.qemu.xml.template 2010-09-07 12:34:37 +0000 > +++ nova/virt/libvirt.qemu.xml.template 2010-09-21 10:09:06 +0000 > @@ -20,6 +20,9 @@ > > > > + > + > + > > > > > === modified file 'nova/virt/libvirt.uml.xml.template' > --- nova/virt/libvirt.uml.xml.template 2010-09-07 12:34:37 +0000 > +++ nova/virt/libvirt.uml.xml.template 2010-09-21 10:09:06 +0000 > @@ -14,6 +14,9 @@ > > > > + > + > + > > > > > === modified file 'nova/virt/libvirt_conn.py' > --- nova/virt/libvirt_conn.py 2010-09-09 15:55:09 +0000 > +++ nova/virt/libvirt_conn.py 
2010-09-21 10:09:06 +0000 > @@ -27,6 +27,7 @@ > > from twisted.internet import defer > from twisted.internet import task > +from twisted.internet import threads > > from nova import db > from nova import exception > @@ -214,6 +215,7 @@ > instance['id'], > power_state.NOSTATE, > 'launching') > + yield NWFilterFirewall(self._conn).setup_nwfilters_for_instance(instance) > yield self._create_image(instance, xml) > yield self._conn.createXML(xml, 0) > # TODO(termie): this should actually register > @@ -285,7 +287,6 @@ > address = db.instance_get_fixed_address(None, inst['id']) > with open(FLAGS.injected_network_template) as f: > net = f.read() % {'address': address, > - 'network': network_ref['network'], > 'netmask': network_ref['netmask'], > 'gateway': network_ref['gateway'], > 'broadcast': network_ref['broadcast'], > @@ -317,6 +318,7 @@ > network = db.project_get_network(None, instance['project_id']) > # FIXME(vish): stick this in db > instance_type = instance_types.INSTANCE_TYPES[instance['instance_type']] > + ip_address = db.instance_get_fixed_address({}, instance['id']) > xml_info = {'type': FLAGS.libvirt_type, > 'name': instance['name'], > 'basepath': os.path.join(FLAGS.instances_path, > @@ -324,7 +326,8 @@ > 'memory_kb': instance_type['memory_mb'] * 1024, > 'vcpus': instance_type['vcpus'], > 'bridge_name': network['bridge'], > - 'mac_address': instance['mac_address']} > + 'mac_address': instance['mac_address'], > + 'ip_address': ip_address } > libvirt_xml = self.libvirt_xml % xml_info > logging.debug('instance %s: finished toXML method', instance['name']) > > @@ -438,3 +441,155 @@ > """ > domain = self._conn.lookupByName(instance_name) > return domain.interfaceStats(interface) > + > + > + def refresh_security_group(self, security_group_id): > + fw = NWFilterFirewall(self._conn) > + fw.ensure_security_group_filter(security_group_id, override=True) > + > + > +class NWFilterFirewall(object): > + """ > + This class implements a network filtering mechanism versatile > 
+ enough for EC2 style Security Group filtering by leveraging > + libvirt's nwfilter. > + > + First, all instances get a filter ("nova-base-filter") applied. > + This filter drops all incoming ipv4 and ipv6 connections. > + Outgoing connections are never blocked. > + > + Second, every security group maps to a nwfilter filter(*). > + NWFilters can be updated at runtime and changes are applied > + immediately, so changes to security groups can be applied at > + runtime (as mandated by the spec). > + > + Security group filters are named "nova-secgroup-<id>" where > + <id> is the internal id of the security group. They're applied only on > + hosts that have instances in the security group in question. > + > + Updates to security groups are done by updating the data model > + (in response to API calls) followed by a request sent to all > + the nodes with instances in the security group to refresh the > + security group. > + > + Each instance has its own NWFilter, which references the above > + mentioned security group NWFilters. This was done because > + interfaces can only reference one filter while filters can > + reference multiple other filters. This has the added benefit of > + actually being able to add and remove security groups from an > + instance at run time. This functionality is not exposed anywhere, > + though. > + > + Outstanding questions: > + > + The name is unique, so would there be any good reason to sync > + the uuid across the nodes (by assigning it from the datamodel)? > + > + > + (*) This sentence brought to you by the redundancy department of > + redundancy. > + """ > + > + def __init__(self, get_connection): > + self._conn = get_connection > + > + > + def nova_base_filter(self): > + return ''' > + 26717364-50cf-42d1-8185-29bf893ab110 > + > + > + > + > + > + > + > +''' > + > + > + def setup_nwfilters_for_instance(self, instance): > + """ > + Creates an NWFilter for the given instance.
In the process, > + it makes sure the filters for the security groups as well as > + the base filter are all in place. > + """ > + > + d = self.ensure_base_filter() > + > + nwfilter_xml = ("\n" + > + " \n" > + ) % instance['name'] > + > + for security_group in instance.security_groups: > + d.addCallback(lambda _:self.ensure_security_group_filter(security_group.id)) > + > + nwfilter_xml += (" \n" > + ) % security_group.id > + nwfilter_xml += "" > + > + d.addCallback(lambda _: threads.deferToThread( > + self._conn.nwfilterDefineXML, > + nwfilter_xml)) > + return d > + > + > + def _nwfilter_name_for_security_group(self, security_group_id): > + return 'nova-secgroup-%d' % (security_group_id,) > + > + > + # TODO(soren): Should override be the default (and should it even > + # be optional)? We save a bit of processing time in > + # libvirt by only defining this conditionally, but > + # we still have to go and ask libvirt if the group > + # is already defined, and there's the off chance > + # of inconsistencies having snuck in which would get > + # fixed by just redefining the filter.
> + def define_filter(self, name, xml_generator, override=False): > + if not override: > + def _already_exists_check(filterlist, filter): > + return filter in filterlist > + d = threads.deferToThread(self._conn.listNWFilters) > + d.addCallback(_already_exists_check, name) > + else: > + # Pretend we looked it up and it wasn't defined > + d = defer.succeed(False) > + def _define_if_not_exists(exists, xml_generator): > + if not exists: > + xml = xml_generator() > + return threads.deferToThread(self._conn.nwfilterDefineXML, xml) > + d.addCallback(_define_if_not_exists, xml_generator) > + return d > + > + > + def ensure_base_filter(self): > + return self.define_filter('nova-base-filter', self.nova_base_filter) > + > + > + def ensure_security_group_filter(self, security_group_id, override=False): > + return self.define_filter( > + self._nwfilter_name_for_security_group(security_group_id), > + lambda:self.security_group_to_nwfilter_xml(security_group_id), > + override=override) > + > + > + def security_group_to_nwfilter_xml(self, security_group_id): > + security_group = db.security_group_get({}, security_group_id) > + rule_xml = "" > + for rule in security_group.rules: > + rule_xml += "" > + if rule.cidr: > + rule_xml += "<%s srcipaddr='%s' " % (rule.protocol, rule.cidr) > + if rule.protocol in ['tcp', 'udp']: > + rule_xml += "dstportstart='%s' dstportend='%s' " % \ > + (rule.from_port, rule.to_port) > + elif rule.protocol == 'icmp': > + logging.info('rule.protocol: %r, rule.from_port: %r, rule.to_port: %r' % (rule.protocol, rule.from_port, rule.to_port)) > + if rule.from_port != -1: > + rule_xml += "type='%s' " % rule.from_port > + if rule.to_port != -1: > + rule_xml += "code='%s' " % rule.to_port > + > + rule_xml += '/>\n' > + rule_xml += "\n" > + xml = '''%s''' % (security_group_id, rule_xml,) > + return xml > > === modified file 'run_tests.py' > --- run_tests.py 2010-09-20 21:38:38 +0000 > +++ run_tests.py 2010-09-21 10:09:06 +0000 > @@ -64,6 +64,7 @@ > from 
nova.tests.service_unittest import * > from nova.tests.validator_unittest import * > from nova.tests.volume_unittest import * > +from nova.tests.virt_unittest import * > > > FLAGS = flags.FLAGS >